diff --git a/prismaudio_core/models/__init__.py b/prismaudio_core/models/__init__.py new file mode 100644 index 0000000..aed753a --- /dev/null +++ b/prismaudio_core/models/__init__.py @@ -0,0 +1,9 @@ +""" +PrismAudio model modules for inference. + +Re-exports create_model_from_config from the factory module. +""" + +from prismaudio_core.factory import create_model_from_config + +__all__ = ["create_model_from_config"] diff --git a/prismaudio_core/models/adp.py b/prismaudio_core/models/adp.py new file mode 100644 index 0000000..49eb526 --- /dev/null +++ b/prismaudio_core/models/adp.py @@ -0,0 +1,1588 @@ +# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License +# License can be found in LICENSES/LICENSE_ADP.txt + +import math +from inspect import isfunction +from math import ceil, floor, log, pi, log2 +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from packaging import version + +import torch +import torch.nn as nn +from einops import rearrange, reduce, repeat +from einops.layers.torch import Rearrange +from einops_exts import rearrange_many +from torch import Tensor, einsum +from torch.backends.cuda import sdp_kernel +from torch.nn import functional as F +from dac.nn.layers import Snake1d + +""" +Utils +""" + + +class ConditionedSequential(nn.Module): + def __init__(self, *modules): + super().__init__() + self.module_list = nn.ModuleList(*modules) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None): + for module in self.module_list: + x = module(x, mapping) + return x + +T = TypeVar("T") + +def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: + if exists(val): + return val + return d() if isfunction(d) else d + +def exists(val: Optional[T]) -> T: + return val is not None + +def closest_power_2(x: float) -> int: + exponent = log2(x) + distance_fn = lambda z: abs(x - 2 ** z) # noqa + exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) + return 2 ** int(exponent_closest) + +def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: + return_dicts: Tuple[Dict, Dict] = ({}, {}) + for key in d.keys(): + no_prefix = int(not key.startswith(prefix)) + return_dicts[no_prefix][key] = d[key] + return return_dicts + +def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: + kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) + if keep_prefix: + return kwargs_with_prefix, kwargs + kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} + return kwargs_no_prefix, kwargs + +""" +Convolutional Blocks +""" +import typing as tp + +# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License +# License available in LICENSES/LICENSE_META.txt + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class Conv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + dilation = self.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding)) + return super().forward(x) + +class ConvTranspose1d(nn.ConvTranspose1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + padding_total = kernel_size - stride + + y = super().forward(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if causal: + padding_right = ceil(padding_total) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y + + +def Downsample1d( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor + ) + + +def Upsample1d( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3 + ), + ) + else: + return ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor + ) + + +class ConvBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + num_groups: int = 8, + use_norm: bool = True, + use_snake: bool = False + ) -> None: + super().__init__() + + self.groupnorm = ( + nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) + if use_norm + else nn.Identity() + ) + + if use_snake: + self.activation = Snake1d(in_channels) + else: + self.activation = nn.SiLU() + + self.project = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + def forward( + self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False + ) -> Tensor: + x = self.groupnorm(x) + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + x = self.activation(x) + return self.project(x, causal=causal) + + +class MappingToScaleShift(nn.Module): + def __init__( + self, + features: int, + channels: int, + ): + super().__init__() + + self.to_scale_shift = nn.Sequential( + nn.SiLU(), + nn.Linear(in_features=features, out_features=channels * 2), + ) + + def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: + scale_shift = self.to_scale_shift(mapping) + scale_shift = rearrange(scale_shift, "b c -> b c 1") + scale, shift = scale_shift.chunk(2, dim=1) + return scale, shift + + +class ResnetBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + use_norm: bool = True, + use_snake: bool = False, + num_groups: int = 8, + context_mapping_features: Optional[int] = None, + ) -> None: + super().__init__() + + self.use_mapping = exists(context_mapping_features) + + self.block1 = ConvBlock1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + if self.use_mapping: + assert exists(context_mapping_features) + self.to_scale_shift = MappingToScaleShift( + features=context_mapping_features, channels=out_channels + ) + + self.block2 = ConvBlock1d( + in_channels=out_channels, + out_channels=out_channels, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + self.to_out = ( + Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + assert_message = "context mapping required if context_mapping_features > 0" + assert not (self.use_mapping ^ exists(mapping)), assert_message + + h = self.block1(x, causal=causal) + + scale_shift = None + if self.use_mapping: + scale_shift = self.to_scale_shift(mapping) + + h = self.block2(h, scale_shift=scale_shift, causal=causal) + + return h + self.to_out(x) + + +class Patcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + assert_message = f"out_channels must be divisible by patch_size ({patch_size})" + assert out_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels, + out_channels=out_channels // patch_size, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.block(x, mapping, causal=causal) + x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) + return x + + +class Unpatcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False + ): + super().__init__() + assert_message = f"in_channels must be divisible by patch_size ({patch_size})" + assert in_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels // patch_size, + out_channels=out_channels, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) + x = self.block(x, mapping, causal=causal) + return x + + +""" +Attention Components +""" +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +def add_mask(sim: Tensor, mask: Tensor) -> Tensor: + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + +def causal_mask(q: Tensor, k: Tensor) -> Tensor: + b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device + mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) + mask = repeat(mask, "n m -> b n m", b=b) + return mask + +class AttentionBase(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + ): + super().__init__() + self.scale = head_features**-0.5 + self.num_heads = num_heads + mid_features = head_features * num_heads + out_features = default(out_features, features) + + self.to_out = nn.Linear( + in_features=mid_features, out_features=out_features + ) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False + ) -> Tensor: + # Split heads + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) + + if not self.use_flash: + if is_causal and not mask: + # Mask out future tokens for causal attention + mask = causal_mask(q, k) + + # Compute similarity matrix and add eventual mask + sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale + sim = add_mask(sim, mask) if exists(mask) else sim + + # Get attention matrix with softmax + attn = sim.softmax(dim=-1, dtype=torch.float32) + + # Compute values + out = einsum("... n m, ... m d -> ... n d", attn, v) + else: + with sdp_kernel(*self.sdp_kernel_config): + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + +class Attention(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + context_features: Optional[int] = None, + causal: bool = False, + ): + super().__init__() + self.context_features = context_features + self.causal = causal + mid_features = head_features * num_heads + context_features = default(context_features, features) + + self.norm = nn.LayerNorm(features) + self.norm_context = nn.LayerNorm(context_features) + self.to_q = nn.Linear( + in_features=features, out_features=mid_features, bias=False + ) + self.to_kv = nn.Linear( + in_features=context_features, out_features=mid_features * 2, bias=False + ) + self.attention = AttentionBase( + features, + num_heads=num_heads, + head_features=head_features, + out_features=out_features, + ) + + def forward( + self, + x: Tensor, # [b, n, c] + context: Optional[Tensor] = None, # [b, m, d] + context_mask: Optional[Tensor] = None, # [b, m], false is masked, + causal: Optional[bool] = False, + ) -> Tensor: + assert_message = "You must provide a context when using context_features" + assert not self.context_features or exists(context), assert_message + # Use context if provided + context = default(context, x) + # Normalize then compute q from input and k,v from context + x, context = self.norm(x), self.norm_context(context) + + q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) + + if exists(context_mask): + # Mask out cross-attention for padding tokens + mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) + k, v = k * mask, v * mask + + # Compute and return attention + return self.attention(q, k, v, is_causal=self.causal or causal) + + +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +""" +Transformer Blocks +""" + + +class TransformerBlock(nn.Module): + def __init__( + self, + features: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.use_cross_attention = exists(context_features) and context_features > 0 + + self.attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features + ) + + if self.use_cross_attention: + self.cross_attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features, + context_features=context_features + ) + + self.feed_forward = FeedForward(features=features, multiplier=multiplier) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: + x = self.attention(x, causal=causal) + x + if self.use_cross_attention: + x = self.cross_attention(x, context=context, context_mask=context_mask) + x + x = self.feed_forward(x) + x + return x + + +""" +Transformers +""" + + +class Transformer1d(nn.Module): + def __init__( + self, + num_layers: int, + channels: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.to_in = nn.Sequential( + nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + Rearrange("b c t -> b t c"), + ) + + self.blocks = nn.ModuleList( + [ + TransformerBlock( + features=channels, + head_features=head_features, + num_heads=num_heads, + multiplier=multiplier, + context_features=context_features, + ) + for i in range(num_layers) + ] + ) + + self.to_out = nn.Sequential( + Rearrange("b t c -> b c t"), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + ) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.to_in(x) + for block in self.blocks: + x = block(x, context=context, context_mask=context_mask, causal=causal) + x = self.to_out(x) + return x + + +""" +Time Embeddings +""" + + +class SinusoidalEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device, half_dim = x.device, self.dim // 2 + emb = torch.tensor(log(10000) / (half_dim - 1), device=device) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") + return torch.cat((emb.sin(), emb.cos()), dim=-1) + + +class LearnedPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x: Tensor) -> Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: + return nn.Sequential( + LearnedPositionalEmbedding(dim), + nn.Linear(in_features=dim + 1, out_features=out_features), + ) + + +""" +Encoder/Decoder Components +""" + + +class DownsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_groups: int, + num_layers: int, + kernel_multiplier: int = 2, + use_pre_downsample: bool = True, + use_skip: bool = False, + use_snake: bool = False, + extract_channels: int = 0, + context_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + self.use_pre_downsample = use_pre_downsample + self.use_skip = use_skip + self.use_transformer = num_transformer_blocks > 0 + self.use_extract = extract_channels > 0 + self.use_context = context_channels > 0 + + channels = out_channels if use_pre_downsample else in_channels + + self.downsample = Downsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + kernel_multiplier=kernel_multiplier, + ) + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + context_channels if i == 0 else channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for i in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + channels: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: + + if self.use_pre_downsample: + x = self.downsample(x) + + if self.use_context and exists(channels): + x = torch.cat([x, channels], dim=1) + + skips = [] + for block in self.blocks: + x = block(x, mapping=mapping, causal=causal) + skips += [x] if self.use_skip else [] + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + skips += [x] if self.use_skip else [] + + if not self.use_pre_downsample: + x = self.downsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return (x, skips) if self.use_skip else x + + +class UpsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_layers: int, + num_groups: int, + use_nearest: bool = False, + use_pre_upsample: bool = False, + use_skip: bool = False, + use_snake: bool = False, + skip_channels: int = 0, + use_skip_scale: bool = False, + extract_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + + self.use_extract = extract_channels > 0 + self.use_pre_upsample = use_pre_upsample + self.use_transformer = num_transformer_blocks > 0 + self.use_skip = use_skip + self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 + + channels = out_channels if use_pre_upsample else in_channels + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + skip_channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for _ in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.upsample = Upsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + use_nearest=use_nearest, + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: + return torch.cat([x, skip * self.skip_scale], dim=1) + + def forward( + self, + x: Tensor, + *, + skips: Optional[List[Tensor]] = None, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, Tensor], Tensor]: + + if self.use_pre_upsample: + x = self.upsample(x) + + for block in self.blocks: + x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x + x = block(x, mapping=mapping, causal=causal) + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + + if not self.use_pre_upsample: + x = self.upsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return x + + +class BottleneckBlock1d(nn.Module): + def __init__( + self, + channels: int, + *, + num_groups: int, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + self.use_transformer = num_transformer_blocks > 0 + + self.pre_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.post_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Tensor: + x = self.pre_block(x, mapping=mapping, causal=causal) + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + x = self.post_block(x, mapping=mapping, causal=causal) + return x + + +""" +UNet +""" + + +class UNet1d(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + multipliers: Sequence[int], + factors: Sequence[int], + num_blocks: Sequence[int], + attentions: Sequence[int], + patch_size: int = 1, + resnet_groups: int = 8, + use_context_time: bool = True, + kernel_multiplier_downsample: int = 2, + use_nearest_upsample: bool = False, + use_skip_scale: bool = True, + use_snake: bool = False, + use_stft: bool = False, + use_stft_context: bool = False, + out_channels: Optional[int] = None, + context_features: Optional[int] = None, + context_features_multiplier: int = 4, + context_channels: Optional[Sequence[int]] = None, + context_embedding_features: Optional[int] = None, + **kwargs, + ): + super().__init__() + out_channels = default(out_channels, in_channels) + context_channels = list(default(context_channels, [])) + num_layers = len(multipliers) - 1 + use_context_features = exists(context_features) + use_context_channels = len(context_channels) > 0 + context_mapping_features = None + + attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) + + self.num_layers = num_layers + self.use_context_time = use_context_time + self.use_context_features = use_context_features + self.use_context_channels = use_context_channels + self.use_stft = use_stft + self.use_stft_context = use_stft_context + + self.context_features = context_features + context_channels_pad_length = num_layers + 1 - len(context_channels) + context_channels = context_channels + [0] * context_channels_pad_length + self.context_channels = context_channels + self.context_embedding_features = context_embedding_features + + if use_context_channels: + has_context = [c > 0 for c in context_channels] + self.has_context = has_context + self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] + + assert ( + len(factors) == num_layers + and len(attentions) >= num_layers + and len(num_blocks) == num_layers + ) + + if use_context_time or use_context_features: + context_mapping_features = channels * context_features_multiplier + + self.to_mapping = nn.Sequential( + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + ) + + if use_context_time: + assert exists(context_mapping_features) + self.to_time = nn.Sequential( + TimePositionalEmbedding( + dim=channels, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_context_features: + assert exists(context_features) and exists(context_mapping_features) + self.to_features = nn.Sequential( + nn.Linear( + in_features=context_features, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_stft: + stft_kwargs, kwargs = groupby("stft_", kwargs) + assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" + stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 + in_channels *= stft_channels + out_channels *= stft_channels + context_channels[0] *= stft_channels if use_stft_context else 1 + assert exists(in_channels) and exists(out_channels) + self.stft = STFT(**stft_kwargs) + + assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" + + self.to_in = Patcher( + in_channels=in_channels + context_channels[0], + out_channels=channels * multipliers[0], + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + self.downsamples = nn.ModuleList( + [ + DownsampleBlock1d( + in_channels=channels * multipliers[i], + out_channels=channels * multipliers[i + 1], + context_mapping_features=context_mapping_features, + context_channels=context_channels[i + 1], + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i], + factor=factors[i], + kernel_multiplier=kernel_multiplier_downsample, + num_groups=resnet_groups, + use_pre_downsample=True, + use_skip=True, + use_snake=use_snake, + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in range(num_layers) + ] + ) + + self.bottleneck = BottleneckBlock1d( + channels=channels * multipliers[-1], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_groups=resnet_groups, + num_transformer_blocks=attentions[-1], + use_snake=use_snake, + **attention_kwargs, + ) + + self.upsamples = nn.ModuleList( + [ + UpsampleBlock1d( + in_channels=channels * multipliers[i + 1], + out_channels=channels * multipliers[i], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i] + (1 if attentions[i] else 0), + factor=factors[i], + use_nearest=use_nearest_upsample, + num_groups=resnet_groups, + use_skip_scale=use_skip_scale, + use_pre_upsample=False, + use_skip=True, + use_snake=use_snake, + skip_channels=channels * multipliers[i + 1], + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in reversed(range(num_layers)) + ] + ) + + self.to_out = Unpatcher( + in_channels=channels * multipliers[0], + out_channels=out_channels, + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def get_channels( + self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 + ) -> Optional[Tensor]: + """Gets context channels at `layer` and checks that shape is correct""" + use_context_channels = self.use_context_channels and self.has_context[layer] + if not use_context_channels: + return None + assert exists(channels_list), "Missing context" + # Get channels index (skipping zero channel contexts) + channels_id = self.channels_ids[layer] + # Get channels + channels = channels_list[channels_id] + message = f"Missing context for layer {layer} at index {channels_id}" + assert exists(channels), message + # Check channels + num_channels = self.context_channels[layer] + message = f"Expected context with {num_channels} channels at idx {channels_id}" + assert channels.shape[1] == num_channels, message + # STFT channels if requested + channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa + return channels + + def get_mapping( + self, time: Optional[Tensor] = None, features: Optional[Tensor] = None + ) -> Optional[Tensor]: + """Combines context time features and features into mapping""" + items, mapping = [], None + # Compute time features + if self.use_context_time: + assert_message = "use_context_time=True but no time features provided" + assert exists(time), assert_message + items += [self.to_time(time)] + # Compute features + if self.use_context_features: + assert_message = "context_features exists but no features provided" + assert exists(features), assert_message + items += [self.to_features(features)] + # Compute joint mapping + if self.use_context_time or self.use_context_features: + mapping = reduce(torch.stack(items), "n b m -> b m", "sum") + mapping = self.to_mapping(mapping) + return mapping + + def forward( + self, + x: Tensor, + time: Optional[Tensor] = None, + *, + features: Optional[Tensor] = None, + channels_list: Optional[Sequence[Tensor]] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False, + ) -> Tensor: + channels = self.get_channels(channels_list, layer=0) + # Apply stft if required + x = self.stft.encode1d(x) if self.use_stft else x # type: ignore + # Concat context channels at layer 0 if provided + x = torch.cat([x, channels], dim=1) if exists(channels) else x + # Compute mapping from time and features + mapping = self.get_mapping(time, features) + x = self.to_in(x, mapping, causal=causal) + skips_list = [x] + + for i, downsample in enumerate(self.downsamples): + channels = self.get_channels(channels_list, layer=i + 1) + x, skips = downsample( + x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal + ) + skips_list += [skips] + + x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + for i, upsample in enumerate(self.upsamples): + skips = skips_list.pop() + x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + x += skips_list.pop() + x = self.to_out(x, mapping, causal=causal) + x = self.stft.decode1d(x) if self.use_stft else x + + return x + + +""" Conditioning Modules """ + + +class FixedEmbedding(nn.Module): + def __init__(self, max_length: int, features: int): + super().__init__() + self.max_length = max_length + self.embedding = nn.Embedding(max_length, features) + + def forward(self, x: Tensor) -> Tensor: + batch_size, length, device = *x.shape[0:2], x.device + assert_message = "Input sequence length must be <= max_length" + assert length <= self.max_length, assert_message + position = torch.arange(length, device=device) + fixed_embedding = self.embedding(position) + fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) + return fixed_embedding + + +def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: + if proba == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif proba == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) + + +class UNetCFG1d(UNet1d): + + """UNet1d with Classifier-Free Guidance""" + + def __init__( + self, + context_embedding_max_length: int, + context_embedding_features: int, + use_xattn_time: bool = False, + **kwargs, + ): + super().__init__( + context_embedding_features=context_embedding_features, **kwargs + ) + + self.use_xattn_time = use_xattn_time + + if use_xattn_time: + assert exists(context_embedding_features) + self.to_time_embedding = nn.Sequential( + TimePositionalEmbedding( + dim=kwargs["channels"], out_features=context_embedding_features + ), + nn.GELU(), + ) + + context_embedding_max_length += 1 # Add one for time embedding + + self.fixed_embedding = FixedEmbedding( + max_length=context_embedding_max_length, features=context_embedding_features + ) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + embedding: Tensor, + embedding_mask: Optional[Tensor] = None, + embedding_scale: float = 1.0, + embedding_mask_proba: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + scale_phi: float = 0.4, + negative_embedding: Optional[Tensor] = None, + negative_embedding_mask: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + b, device = embedding.shape[0], embedding.device + + if self.use_xattn_time: + embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) + + if embedding_mask is not None: + embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) + + fixed_embedding = self.fixed_embedding(embedding) + + if embedding_mask_proba > 0.0: + # Randomly mask embedding + batch_mask = rand_bool( + shape=(b, 1, 1), proba=embedding_mask_proba, device=device + ) + embedding = torch.where(batch_mask, fixed_embedding, embedding) + + if embedding_scale != 1.0: + if batch_cfg: + batch_x = torch.cat([x, x], dim=0) + batch_time = torch.cat([time, time], dim=0) + + if negative_embedding is not None: + if negative_embedding_mask is not None: + negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) + + negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) + + batch_embed = torch.cat([embedding, negative_embedding], dim=0) + + else: + batch_embed = torch.cat([embedding, fixed_embedding], dim=0) + + batch_mask = None + if embedding_mask is not None: + batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) + + batch_features = None + features = kwargs.pop("features", None) + if self.use_context_features: + batch_features = torch.cat([features, features], dim=0) + + batch_channels = None + channels_list = kwargs.pop("channels_list", None) + if self.use_context_channels: + batch_channels = [] + for channels in channels_list: + batch_channels += [torch.cat([channels, channels], dim=0)] + + # Compute both normal and fixed embedding outputs + batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) + out, out_masked = batch_out.chunk(2, dim=0) + + else: + # Compute both normal and fixed embedding outputs + out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) + + out_cfg = out_masked + (out - out_masked) * embedding_scale + + if rescale_cfg: + + out_std = out.std(dim=1, keepdim=True) + out_cfg_std = out_cfg.std(dim=1, keepdim=True) + + return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg + + else: + + return out_cfg + + else: + return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + + +class UNetNCCA1d(UNet1d): + + """UNet1d with Noise Channel Conditioning Augmentation""" + + def __init__(self, context_features: int, **kwargs): + super().__init__(context_features=context_features, **kwargs) + self.embedder = NumberEmbedder(features=context_features) + + def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: + x = x if torch.is_tensor(x) else torch.tensor(x) + return x.expand(shape) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + channels_list: Sequence[Tensor], + channels_augmentation: Union[ + bool, Sequence[bool], Sequence[Sequence[bool]], Tensor + ] = False, + channels_scale: Union[ + float, Sequence[float], Sequence[Sequence[float]], Tensor + ] = 0, + **kwargs, + ) -> Tensor: + b, n = x.shape[0], len(channels_list) + channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) + channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) + + # Augmentation (for each channel list item) + for i in range(n): + scale = channels_scale[:, i] * channels_augmentation[:, i] + scale = rearrange(scale, "b -> b 1 1") + item = channels_list[i] + channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa + + # Scale embedding (sum reduction if more than one channel list item) + channels_scale_emb = self.embedder(channels_scale) + channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") + + return super().forward( + x=x, + time=time, + channels_list=channels_list, + features=channels_scale_emb, + **kwargs, + ) + + +class UNetAll1d(UNetCFG1d, UNetNCCA1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, *args, **kwargs): # type: ignore + return UNetCFG1d.forward(self, *args, **kwargs) + + +def XUNet1d(type: str = "base", **kwargs) -> UNet1d: + if type == "base": + return UNet1d(**kwargs) + elif type == "all": + return UNetAll1d(**kwargs) + elif type == "cfg": + return UNetCFG1d(**kwargs) + elif type == "ncca": + return UNetNCCA1d(**kwargs) + else: + raise ValueError(f"Unknown XUNet1d type: {type}") + +class NumberEmbedder(nn.Module): + def __init__( + self, + features: int, + dim: int = 256, + ): + super().__init__() + self.features = features + self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) + + def forward(self, x: Union[List[float], Tensor]) -> Tensor: + if not torch.is_tensor(x): + device = next(self.embedding.parameters()).device + x = torch.tensor(x, device=device) + assert isinstance(x, Tensor) + shape = x.shape + x = rearrange(x, "... -> (...)") + embedding = self.embedding(x) + x = embedding.view(*shape, self.features) + return x # type: ignore + + +""" +Audio Transforms +""" + + +class STFT(nn.Module): + """Helper for torch stft and istft""" + + def __init__( + self, + num_fft: int = 1023, + hop_length: int = 256, + window_length: Optional[int] = None, + length: Optional[int] = None, + use_complex: bool = False, + ): + super().__init__() + self.num_fft = num_fft + self.hop_length = default(hop_length, floor(num_fft // 4)) + self.window_length = default(window_length, num_fft) + self.length = length + self.register_buffer("window", torch.hann_window(self.window_length)) + self.use_complex = use_complex + + def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: + b = wave.shape[0] + wave = rearrange(wave, "b c t -> (b c) t") + + stft = torch.stft( + wave, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + return_complex=True, + normalized=True, + ) + + if self.use_complex: + # Returns real and imaginary + stft_a, stft_b = stft.real, stft.imag + else: + # Returns magnitude and phase matrices + magnitude, phase = torch.abs(stft), torch.angle(stft) + stft_a, stft_b = magnitude, phase + + return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) + + def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: + b, l = stft_a.shape[0], stft_a.shape[-1] # noqa + length = closest_power_2(l * self.hop_length) + + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") + + if self.use_complex: + real, imag = stft_a, stft_b + else: + magnitude, phase = stft_a, stft_b + real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) + + stft = torch.stack([real, imag], dim=-1) + + wave = torch.istft( + stft, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + length=default(self.length, length), + normalized=True, + ) + + return rearrange(wave, "(b c) t -> b c t", b=b) + + def encode1d( + self, wave: Tensor, stacked: bool = True + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + stft_a, stft_b = self.encode(wave) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") + return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) + + def decode1d(self, stft_pair: Tensor) -> Tensor: + f = self.num_fft // 2 + 1 + stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) + return self.decode(stft_a, stft_b) diff --git a/prismaudio_core/models/autoencoders.py b/prismaudio_core/models/autoencoders.py new file mode 100644 index 0000000..b279823 --- /dev/null +++ b/prismaudio_core/models/autoencoders.py @@ -0,0 +1,830 @@ +import torch +import math +import numpy as np + +from torch import nn +from torch.nn import functional as F +from torchaudio import transforms as T +from alias_free_torch import Activation1d +from dac.nn.layers import WNConv1d, WNConvTranspose1d +from typing import Literal, Dict, Any + +from .blocks import SnakeBeta +from .bottleneck import Bottleneck, DiscreteBottleneck +from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper +from .pretransforms import Pretransform + + +def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): + """Minimal stub for inference.utils.prepare_audio used by autoencoders.""" + import torchaudio.transforms as T + import torch + + if in_sr != target_sr: + resample_tf = T.Resample(in_sr, target_sr).to(device) + audio = resample_tf(audio) + + if audio.shape[0] > target_channels: + audio = audio[:target_channels] + elif audio.shape[0] < target_channels: + audio = audio.repeat(target_channels // audio.shape[0] + 1, 1)[:target_channels] + + if audio.shape[-1] < target_length: + audio = torch.nn.functional.pad(audio, (0, target_length - audio.shape[-1])) + elif audio.shape[-1] > target_length: + audio = audio[..., :target_length] + + return audio.unsqueeze(0) + + +def _lazy_create_pretransform_from_config(pretransform, sample_rate): + from prismaudio_core.factory import create_pretransform_from_config + return create_pretransform_from_config(pretransform, sample_rate) + + +def _lazy_create_bottleneck_from_config(bottleneck): + from prismaudio_core.factory import create_bottleneck_from_config + return create_bottleneck_from_config(bottleneck) + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + act = Activation1d(act) + + return act + +class ResidualUnit(nn.Module): + def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): + super().__init__() + + self.dilation = dilation + + padding = (dilation * (7-1)) // 2 + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=7, dilation=dilation, padding=padding), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=out_channels, out_channels=out_channels, + kernel_size=1) + ) + + def forward(self, x): + res = x + + #x = checkpoint(self.layers, x) + x = self.layers(x) + + return x + res + +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False): + super().__init__() + + self.layers = nn.Sequential( + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=9, use_snake=use_snake), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), + ) + + def forward(self, x): + return self.layers(x) + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): + super().__init__() + + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, + stride=1, + bias=False, + padding='same') + ) + else: + upsample_layer = WNConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + upsample_layer, + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=9, use_snake=use_snake), + ) + + def forward(self, x): + return self.layers(x) + +class OobleckEncoder(nn.Module): + def __init__(self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) + ] + + for i in range(self.depth-1): + layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), + WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__(self, + out_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), + ] + + for i in range(self.depth-1, 0, -1): + layers += [DecoderBlock( + in_channels=c_mults[i]*channels, + out_channels=c_mults[i-1]*channels, + stride=strides[i-1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample + ) + ] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), + WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), + nn.Tanh() if final_tanh else nn.Identity() + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class DACEncoderWrapper(nn.Module): + def __init__(self, in_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Encoder as DACEncoder + + latent_dim = kwargs.pop("latent_dim", None) + + encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"])) + self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs) + self.latent_dim = latent_dim + + # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility + self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() + + if in_channels != 1: + self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3) + + def forward(self, x): + x = self.encoder(x) + x = self.proj_out(x) + return x + +class DACDecoderWrapper(nn.Module): + def __init__(self, latent_dim, out_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Decoder as DACDecoder + + self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels) + + self.latent_dim = latent_dim + + def forward(self, x): + return self.decoder(x) + +class AudioAutoencoder(nn.Module): + def __init__( + self, + encoder, + decoder, + latent_dim, + downsampling_ratio, + sample_rate, + io_channels=2, + bottleneck: Bottleneck = None, + pretransform: Pretransform = None, + in_channels = None, + out_channels = None, + soft_clip = False + ): + super().__init__() + + self.downsampling_ratio = downsampling_ratio + self.sample_rate = sample_rate + + self.latent_dim = latent_dim + self.io_channels = io_channels + self.in_channels = io_channels + self.out_channels = io_channels + + self.min_length = self.downsampling_ratio + + if in_channels is not None: + self.in_channels = in_channels + + if out_channels is not None: + self.out_channels = out_channels + + self.bottleneck = bottleneck + + self.encoder = encoder + + self.decoder = decoder + + self.pretransform = pretransform + + self.soft_clip = soft_clip + + self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete + + def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): + + info = {} + # import ipdb + # ipdb.set_trace() + if self.pretransform is not None and not skip_pretransform: + if self.pretransform.enable_grad: + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + else: + with torch.no_grad(): + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + + if self.encoder is not None: + if iterate_batch: + latents = [] + for i in range(audio.shape[0]): + latents.append(self.encoder(audio[i:i+1])) + latents = torch.cat(latents, dim=0) + else: + latents = self.encoder(audio) + else: + latents = audio + + if self.bottleneck is not None: + # TODO: Add iterate batch logic, needs to merge the info dicts + latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs) + + info.update(bottleneck_info) + + if return_info: + return latents, info + + return latents + + def decode(self, latents, iterate_batch=False, **kwargs): + + if self.bottleneck is not None: + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.bottleneck.decode(latents[i:i+1])) + latents = torch.cat(decoded, dim=0) + else: + latents = self.bottleneck.decode(latents) + + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.decoder(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + decoded = self.decoder(latents, **kwargs) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + if iterate_batch: + decodeds = [] + for i in range(decoded.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + if iterate_batch: + decodeds = [] + for i in range(latents.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + + if self.soft_clip: + decoded = torch.tanh(decoded) + + return decoded + + def decode_tokens(self, tokens, **kwargs): + ''' + Decode discrete tokens to audio + Only works with discrete autoencoders + ''' + + assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" + + latents = self.bottleneck.decode_tokens(tokens, **kwargs) + + return self.decode(latents, **kwargs) + + + def preprocess_audio_for_encoder(self, audio, in_sr): + ''' + Preprocess single audio tensor (Channels x Length) to be compatible with the encoder. + If the model is mono, stereo audio will be converted to mono. + Audio will be silence-padded to be a multiple of the model's downsampling ratio. + Audio will be resampled to the model's sample rate. + The output will have batch size 1 and be shape (1 x Channels x Length) + ''' + return self.preprocess_audio_list_for_encoder([audio], [in_sr]) + + def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list): + ''' + Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder. + The audio in that list can be of different lengths and channels. + in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio. + All audio will be resampled to the model's sample rate. + Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. + If the model is mono, all audio will be converted to mono. + The output will be a tensor of shape (Batch x Channels x Length) + ''' + batch_size = len(audio_list) + if isinstance(in_sr_list, int): + in_sr_list = [in_sr_list]*batch_size + assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list" + new_audio = [] + max_length = 0 + # resample & find the max length + for i in range(batch_size): + audio = audio_list[i] + in_sr = in_sr_list[i] + if len(audio.shape) == 3 and audio.shape[0] == 1: + # batchsize 1 was given by accident. Just squeeze it. + audio = audio.squeeze(0) + elif len(audio.shape) == 1: + # Mono signal, channel dimension is missing, unsqueeze it in + audio = audio.unsqueeze(0) + assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" + # Resample audio + if in_sr != self.sample_rate: + resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device) + audio = resample_tf(audio) + new_audio.append(audio) + if audio.shape[-1] > max_length: + max_length = audio.shape[-1] + # Pad every audio to the same length, multiple of model's downsampling ratio + padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length + for i in range(batch_size): + # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model + new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length, + target_channels=self.in_channels, device=new_audio[i].device).squeeze(0) + # convert to tensor + return torch.stack(new_audio) + + def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder. + If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap. + Overlap and chunk_size params are both measured in number of latents (not audio samples) + # and therefore you likely could use the same values with decode_audio. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Encode the entire audio in parallel + return self.encode(audio, **kwargs) + else: + # CHUNKED ENCODING + # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) + # import ipdb + # ipdb.set_trace() + samples_per_latent = self.downsampling_ratio + total_size = audio.shape[2] # in samples + print(f'audio shape: {audio.shape}') + batch_size = audio.shape[0] + chunk_size *= samples_per_latent # converting metric in latents to samples + overlap *= samples_per_latent # converting metric in latents to samples + hop_size = chunk_size - overlap + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = audio[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = audio[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # Note: y_size might be a different value from the latent length used in diffusion training + # because we can encode audio of varying lengths + # However, the audio should've been padded to a multiple of samples_per_latent by now. + y_size = total_size // samples_per_latent + # Create an empty latent, we will populate it with chunks as we encode them + y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) + print(f'y_final shape: {y_final.shape}') + for i in range(num_chunks): + x_chunk = chunks[i,:] + # encode the chunk + y_chunk = self.encode(x_chunk) + print(f'y_chunk shape: {y_chunk.shape}') + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size // samples_per_latent + t_end = t_start + chunk_size // samples_per_latent + # remove the edges of the overlaps + ol = overlap//samples_per_latent//2 + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Decode latents to audio. + If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Decode the entire latent in parallel + return self.decode(latents, **kwargs) + else: + # chunked decoding + hop_size = chunk_size - overlap + total_size = latents.shape[2] + batch_size = latents.shape[0] + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = latents[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = latents[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # samples_per_latent is just the downsampling ratio + samples_per_latent = self.downsampling_ratio + # Create an empty waveform, we will populate it with chunks as decode them + y_size = total_size * samples_per_latent + y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # decode the chunk + y_chunk = self.decode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size * samples_per_latent + t_end = t_start + chunk_size * samples_per_latent + # remove the edges of the overlaps + ol = (overlap//2) * samples_per_latent + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + +class DiffusionAutoencoder(AudioAutoencoder): + def __init__( + self, + diffusion: ConditionedDiffusionModel, + diffusion_downsampling_ratio, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + + self.diffusion = diffusion + + self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio + + if self.encoder is not None: + # Shrink the initial encoder parameters to avoid saturated latents + with torch.no_grad(): + for param in self.encoder.parameters(): + param *= 0.5 + + def decode(self, latents, steps=100): + + upsampled_length = latents.shape[2] * self.downsampling_ratio + + if self.bottleneck is not None: + latents = self.bottleneck.decode(latents) + + if self.decoder is not None: + latents = self.decode(latents) + + # Upsample latents to match diffusion length + if latents.shape[2] != upsampled_length: + latents = F.interpolate(latents, size=upsampled_length, mode='nearest') + + noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device) + from prismaudio_core.inference.sampling import sample + decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + decoded = self.pretransform.decode(decoded) + + return decoded + +# AE factories + +def create_encoder_from_config(encoder_config: Dict[str, Any]): + encoder_type = encoder_config.get("type", None) + assert encoder_type is not None, "Encoder type must be specified" + + if encoder_type == "oobleck": + encoder = OobleckEncoder( + **encoder_config["config"] + ) + + elif encoder_type == "seanet": + from encodec.modules import SEANetEncoder + seanet_encoder_config = encoder_config["config"] + + #SEANet encoder expects strides in reverse order + seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))) + encoder = SEANetEncoder( + **seanet_encoder_config + ) + elif encoder_type == "dac": + dac_config = encoder_config["config"] + + encoder = DACEncoderWrapper(**dac_config) + elif encoder_type == "local_attn": + from .local_attention import TransformerEncoder1D + + local_attn_config = encoder_config["config"] + + encoder = TransformerEncoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown encoder type {encoder_type}") + + requires_grad = encoder_config.get("requires_grad", True) + if not requires_grad: + for param in encoder.parameters(): + param.requires_grad = False + + return encoder + +def create_decoder_from_config(decoder_config: Dict[str, Any]): + decoder_type = decoder_config.get("type", None) + assert decoder_type is not None, "Decoder type must be specified" + + if decoder_type == "oobleck": + decoder = OobleckDecoder( + **decoder_config["config"] + ) + elif decoder_type == "seanet": + from encodec.modules import SEANetDecoder + + decoder = SEANetDecoder( + **decoder_config["config"] + ) + elif decoder_type == "dac": + dac_config = decoder_config["config"] + + decoder = DACDecoderWrapper(**dac_config) + elif decoder_type == "local_attn": + from .local_attention import TransformerDecoder1D + + local_attn_config = decoder_config["config"] + + decoder = TransformerDecoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown decoder type {decoder_type}") + + requires_grad = decoder_config.get("requires_grad", True) + if not requires_grad: + for param in decoder.parameters(): + param.requires_grad = False + + return decoder + +def create_autoencoder_from_config(config: Dict[str, Any]): + + ae_config = config["model"] + + encoder = create_encoder_from_config(ae_config["encoder"]) + decoder = create_decoder_from_config(ae_config["decoder"]) + + bottleneck = ae_config.get("bottleneck", None) + + latent_dim = ae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = ae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = ae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + in_channels = ae_config.get("in_channels", None) + out_channels = ae_config.get("out_channels", None) + + pretransform = ae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = _lazy_create_bottleneck_from_config(bottleneck) + + soft_clip = ae_config["decoder"].get("soft_clip", False) + + return AudioAutoencoder( + encoder, + decoder, + io_channels=io_channels, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + sample_rate=sample_rate, + bottleneck=bottleneck, + pretransform=pretransform, + in_channels=in_channels, + out_channels=out_channels, + soft_clip=soft_clip + ) + +def create_diffAE_from_config(config: Dict[str, Any]): + + diffae_config = config["model"] + + if "encoder" in diffae_config: + encoder = create_encoder_from_config(diffae_config["encoder"]) + else: + encoder = None + + if "decoder" in diffae_config: + decoder = create_decoder_from_config(diffae_config["decoder"]) + else: + decoder = None + + diffusion_model_type = diffae_config["diffusion"]["type"] + + if diffusion_model_type == "DAU1d": + diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"]) + elif diffusion_model_type == "adp_1d": + diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"]) + elif diffusion_model_type == "dit": + diffusion = DiTWrapper(**diffae_config["diffusion"]["config"]) + + latent_dim = diffae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = diffae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = diffae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + bottleneck = diffae_config.get("bottleneck", None) + + pretransform = diffae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = _lazy_create_bottleneck_from_config(bottleneck) + + diffusion_downsampling_ratio = None, + + if diffusion_model_type == "DAU1d": + diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"]) + elif diffusion_model_type == "adp_1d": + diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"]) + elif diffusion_model_type == "dit": + diffusion_downsampling_ratio = 1 + + return DiffusionAutoencoder( + encoder=encoder, + decoder=decoder, + diffusion=diffusion, + io_channels=io_channels, + sample_rate=sample_rate, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + diffusion_downsampling_ratio=diffusion_downsampling_ratio, + bottleneck=bottleneck, + pretransform=pretransform + ) diff --git a/prismaudio_core/models/blocks.py b/prismaudio_core/models/blocks.py new file mode 100644 index 0000000..3c827fd --- /dev/null +++ b/prismaudio_core/models/blocks.py @@ -0,0 +1,339 @@ +from functools import reduce +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.backends.cuda import sdp_kernel +from packaging import version + +from dac.nn.layers import Snake1d + +class ResidualBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return self.main(input) + self.skip(input) + +class ResConvBlock(ResidualBlock): + def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): + skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) + super().__init__([ + nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_mid), + Snake1d(c_mid) if use_snake else nn.GELU(), + nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), + (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), + ], skip) + +class SelfAttention1d(nn.Module): + def __init__(self, c_in, n_head=1, dropout_rate=0.): + super().__init__() + assert c_in % n_head == 0 + self.norm = nn.GroupNorm(1, c_in) + self.n_head = n_head + self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) + self.out_proj = nn.Conv1d(c_in, c_in, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward(self, input): + n, c, s = input.shape + qkv = self.qkv_proj(self.norm(input)) + qkv = qkv.view( + [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) + q, k, v = qkv.chunk(3, dim=1) + scale = k.shape[3]**-0.25 + + if self.use_flash: + with sdp_kernel(*self.sdp_kernel_config): + y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) + else: + att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) + y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) + + + return input + self.dropout(self.out_proj(y)) + +class SkipBlock(nn.Module): + def __init__(self, *main): + super().__init__() + self.main = nn.Sequential(*main) + + def forward(self, input): + return torch.cat([self.main(input), input], dim=1) + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.randn( + [out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + +def expand_to_planes(input, shape): + return input[..., None].repeat([1, 1, shape[2]]) + +_kernels = { + 'linear': + [1 / 8, 3 / 8, 3 / 8, 1 / 8], + 'cubic': + [-0.01171875, -0.03515625, 0.11328125, 0.43359375, + 0.43359375, 0.11328125, -0.03515625, -0.01171875], + 'lanczos3': + [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, + -0.066637322306633, 0.13550527393817902, 0.44638532400131226, + 0.44638532400131226, 0.13550527393817902, -0.066637322306633, + -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] +} + +class Downsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, (self.pad,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv1d(x, weight, stride=2) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +class Upsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + +def Downsample1d_2( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor, + padding=factor * (kernel_multiplier // 2), + ) + + +def Upsample1d_2( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + ) + else: + return nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor, + padding=factor // 2 + factor % 2, + output_padding=factor % 2, + ) + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +#rms_norm = torch.compile(rms_norm) + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) + +def normalize(x, eps=1e-4): + dim = list(range(1, x.ndim)) + n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) + alpha = np.sqrt(n.numel() / x.numel()) + return x / torch.add(eps, n, alpha=alpha) + +class ForcedWNConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1): + super().__init__() + self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) + + def forward(self, x): + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(self.weight)) + + fan_in = self.weight[0].numel() + + w = normalize(self.weight) / math.sqrt(fan_in) + + return F.conv1d(x, w, padding='same') + +# Kernels + +use_compile = True + +def compile(function, *args, **kwargs): + if not use_compile: + return function + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + +@compile +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@compile +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +# Layers + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def forward(self, x): + return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, fix_scale = False, eps=1e-6): + super().__init__() + self.eps = eps + + if fix_scale: + self.register_buffer("scale", torch.ones(shape)) + else: + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + +# try: +# snake_beta = torch.compile(snake_beta) +# except RuntimeError: +# pass + +# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license +# License available in LICENSES/LICENSE_NVIDIA.txt +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x \ No newline at end of file diff --git a/prismaudio_core/models/bottleneck.py b/prismaudio_core/models/bottleneck.py new file mode 100644 index 0000000..5e81cab --- /dev/null +++ b/prismaudio_core/models/bottleneck.py @@ -0,0 +1,355 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from einops import rearrange +from vector_quantize_pytorch import ResidualVQ, FSQ +from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ + +class Bottleneck(nn.Module): + def __init__(self, is_discrete: bool = False): + super().__init__() + + self.is_discrete = is_discrete + + def encode(self, x, return_info=False, **kwargs): + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError + +class DiscreteBottleneck(Bottleneck): + def __init__(self, num_quantizers, codebook_size, tokens_id): + super().__init__(is_discrete=True) + + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + self.tokens_id = tokens_id + + def decode_tokens(self, codes, **kwargs): + raise NotImplementedError + +class TanhBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + self.tanh = nn.Tanh() + + def encode(self, x, return_info=False): + info = {} + + x = torch.tanh(x) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + +def vae_sample(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + return latents, kl + +class VAEBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False, **kwargs): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["kl"] = kl + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + +def compute_mean_kernel(x, y): + kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] + return torch.exp(-kernel_input).mean() + +def compute_mmd(latents): + latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) + noise = torch.randn_like(latents_reshaped) + + latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) + noise_kernel = compute_mean_kernel(noise, noise) + latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) + + mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel + return mmd.mean() + +class WassersteinBottleneck(Bottleneck): + def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False): + super().__init__(is_discrete=False) + + self.noise_augment_dim = noise_augment_dim + self.bypass_mmd = bypass_mmd + + def encode(self, x, return_info=False): + info = {} + + if self.training and return_info: + if self.bypass_mmd: + mmd = torch.tensor(0.0) + else: + mmd = compute_mmd(x) + + info["mmd"] = mmd + + if return_info: + return x, info + + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + +class L2Bottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False): + info = {} + + x = F.normalize(x, dim=1) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return F.normalize(x, dim=1) + +class RVQBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False, **kwargs): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + +class RVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False): + info = {} + + x, kl = vae_sample(*x.chunk(2, dim=1)) + + info["kl"] = kl + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + +class DACRVQBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + self.noise_augment_dim = noise_augment_dim + + def encode(self, x, return_info=False, **kwargs): + info = {} + + info["pre_quantizer"] = x + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + +class DACRVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, n_quantizers: int = None): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["pre_quantizer"] = x + info["kl"] = kl + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + +class FSQBottleneck(DiscreteBottleneck): + def __init__(self, noise_augment_dim=0, **kwargs): + super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices") + + self.noise_augment_dim = noise_augment_dim + + self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]) + + def encode(self, x, return_info=False): + info = {} + + orig_dtype = x.dtype + x = x.float() + + x = rearrange(x, "b c n -> b n c") + x, indices = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + x = x.to(orig_dtype) + + # Reorder indices to match the expected format + indices = rearrange(indices, "b n q -> b q n") + + info["quantizer_indices"] = indices + + if return_info: + return x, info + else: + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + def decode_tokens(self, tokens, **kwargs): + latents = self.quantizer.indices_to_codes(tokens) + + return self.decode(latents, **kwargs) \ No newline at end of file diff --git a/prismaudio_core/models/conditioners.py b/prismaudio_core/models/conditioners.py new file mode 100644 index 0000000..8a5e7ee --- /dev/null +++ b/prismaudio_core/models/conditioners.py @@ -0,0 +1,1100 @@ +#Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py + +import torch +import logging, warnings +import string +import typing as tp +import gc +from typing import Literal, Optional +import os +from .adp import NumberEmbedder +from .pretransforms import Pretransform +from .utils import load_ckpt_state_dict + + +# Stub for training utility - only needed for load_state_dict, not inference +def copy_state_dict(model, state_dict): + """Stub replacement for PrismAudio.training.utils.copy_state_dict""" + model.load_state_dict(state_dict, strict=False) + + +def set_audio_channels(audio, target_channels): + """Stub replacement for PrismAudio.inference.utils.set_audio_channels""" + if audio.shape[1] == target_channels: + return audio + if target_channels == 1: + return audio.mean(dim=1, keepdim=True) + if target_channels == 2 and audio.shape[1] == 1: + return audio.repeat(1, 2, 1) + raise ValueError(f"Cannot convert {audio.shape[1]} channels to {target_channels}") +import numpy as np +from einops import rearrange +from transformers import AutoProcessor, AutoModel +from torch import nn +import torch.nn.functional as F +from .mmmodules.model.low_level import ConvMLP, MLP +from torch.nn.utils.rnn import pad_sequence + +class Conditioner(nn.Module): + def __init__( + self, + dim: int, + output_dim: int, + project_out: bool = False + ): + + super().__init__() + + self.dim = dim + self.output_dim = output_dim + self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity() + + def forward(self, x: tp.Any) -> tp.Any: + raise NotImplementedError() + +class Cond_MLP(Conditioner): + def __init__(self, dim, output_dim, dropout = 0.0): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + self.dropout = dropout + def forward(self, x, device: tp.Any = "cuda"): + x = pad_sequence(x, batch_first=True).to(device) + # x = torch.stack(x, dim=0).to(device) + + if self.dropout > 0.0: + if self.training: + null_embed = torch.zeros_like(x, device=device) + dropout_mask = torch.bernoulli(torch.full((x.shape[0], 1, 1), self.dropout, device=device)).to(torch.bool) + x = torch.where(dropout_mask, null_embed, x) + elif x.shape[0] < 16: # default test batch size=1 + null_embed = torch.zeros_like(x, device=device) + x = torch.cat([x, null_embed], dim=0) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Global_MLP(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + x = x.mean(dim=1) + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Cond_MLP_1(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim), + nn.SiLU(), + MLP(output_dim, output_dim * 4), + ) + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Cond_MLP_Global(Conditioner): + def __init__(self, dim, output_dim, dropout = 0.0): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + self.global_embedder = nn.Sequential( + nn.Linear(output_dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + self.dropout = dropout + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + if self.dropout > 0 and self.training: + null_embed = torch.zeros_like(x, device=device) + dropout_mask = torch.bernoulli(torch.full((x.shape[0], 1, 1), self.dropout, device=device)).to(torch.bool) + x = torch.where(dropout_mask, null_embed, x) + x = self.embedder(x) # B x 117 x C + global_x = self.global_embedder(x[:,0,:]) + return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] + +class Cond_MLP_Global_1(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim), + nn.SiLU(), + MLP(output_dim, output_dim * 4), + ) + self.global_embedder = nn.Sequential( + nn.Linear(dim, output_dim), + MLP(output_dim, output_dim * 4), + ) + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + global_x = self.global_embedder(x.mean(dim=1)) + return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] + +class Cond_MLP_Global_2(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + self.global_embedder = nn.Sequential( + nn.Linear(output_dim, output_dim, bias=False), + ) + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + global_x = self.global_embedder(x.mean(dim=1)) + return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] + +class Sync_MLP(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, dim))) + nn.init.constant_(self.sync_pos_emb, 0) + def forward(self, x, device: tp.Any = "cuda"): + sync_f = torch.stack(x, dim=0).to(device) + bs, length, dim = sync_f.shape + #print(sync_f.shape,flush=True) + # B * num_segments (24) * 8 * 768 + num_sync_segments = length // 8 + sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb + sync_f = sync_f.flatten(1, 2) # (B, VN, D) + x = self.embedder(sync_f) # B x 117 x C + x = x.transpose(1,2) + x = F.interpolate(x, ((int)(194*sync_f.shape[1]/216), ), mode='linear', align_corners=False) + x = x.transpose(1,2) + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Cond_ConvMLP(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim), + nn.SiLU(), + ConvMLP(output_dim, output_dim * 4, kernel_size=1, padding=0), + ) + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Video_Global(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim, global_dim=1536): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) + self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim)) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + data = torch.load(path) + video_feats.append(data['metaclip_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + global_x = self.global_proj(x.mean(dim=1)) + return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] + +class Video_Sync(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + video_feats.append(torch.load(path)['sync_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Text_Linear(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + video_feats.append(torch.load(path)['metaclip_text_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class mm_unchang(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + video_feats.append(torch.load(path)['metaclip_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + return [x] + +class CLIPConditioner(Conditioner): + + CLIP_MODELS = ["metaclip-base", "metaclip-b16", "metaclip-large", "metaclip-huge"] + + CLIP_MODEL_DIMS = { + "metaclip-base": 512, + "metaclip-b16": 512, + "metaclip-large": 768, + "metaclip-huge": 1024, + } + + def __init__( + self, + dim: int, + output_dim: int, + clip_model_name: str = "metaclip-huge", + enable_grad: bool = False, + project_out: bool = False + ): + assert clip_model_name in self.CLIP_MODELS, f"Unknown CLIP model name: {clip_model_name}" + super().__init__(self.CLIP_MODEL_DIMS[clip_model_name], output_dim, project_out=project_out) + + self.enable_grad = enable_grad + model = AutoModel.from_pretrained(f"useful_ckpts/{clip_model_name}").train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + + + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + + def forward(self, images: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + # import ipdb + # ipdb.set_trace() + + self.model.eval() + if not isinstance(images[0], torch.Tensor): + video_feats = [] + for path in images: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + images = torch.stack(video_feats, dim=0).to(device) + else: + images = torch.stack(images, dim=0).to(device) + bsz, t, c, h, w = images.shape + # 使用 rearrange 进行维度合并 + images = rearrange(images, 'b t c h w -> (b t) c h w') + with torch.set_grad_enabled(self.enable_grad): + image_features = self.model.get_image_features(images) + image_features = rearrange(image_features, '(b t) d -> b t d', b=bsz, t=t) + image_features = self.proj_out(image_features) + + + return [image_features, torch.ones(image_features.shape[0], 1).to(device)] + +class IntConditioner(Conditioner): + def __init__(self, + output_dim: int, + min_val: int=0, + max_val: int=512 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True) + + def forward(self, ints: tp.List[int], device=None) -> tp.Any: + + #self.int_embedder.to(device) + + ints = torch.tensor(ints).to(device) + ints = ints.clamp(self.min_val, self.max_val) + + int_embeds = self.int_embedder(ints).unsqueeze(1) + + return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)] + +class NumberConditioner(Conditioner): + ''' + Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings + ''' + def __init__(self, + output_dim: int, + min_val: float=0, + max_val: float=1 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + + self.embedder = NumberEmbedder(features=output_dim) + + def forward(self, floats: tp.List[float], device=None) -> tp.Any: + + # Cast the inputs to floats + floats = [float(x) for x in floats] + + floats = torch.tensor(floats).to(device) + + floats = floats.clamp(self.min_val, self.max_val) + + normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) + + # Cast floats to same type as embedder + embedder_dtype = next(self.embedder.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + float_embeds = self.embedder(normalized_floats).unsqueeze(1) + + return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)] + +class CLAPTextConditioner(Conditioner): + def __init__(self, + output_dim: int, + clap_ckpt_path, + use_text_features = False, + feature_layer_ix: int = -1, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False, + finetune: bool = False): + super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out) + + self.use_text_features = use_text_features + self.feature_layer_ix = feature_layer_ix + self.finetune = finetune + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + import laion_clap + from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + + if self.finetune: + self.model = model + else: + self.__dict__["model"] = model + + state_dict = clap_load_state_dict(clap_ckpt_path) + self.model.model.load_state_dict(state_dict, strict=False) + + if self.finetune: + self.model.model.text_branch.requires_grad_(True) + self.model.model.text_branch.train() + else: + self.model.model.text_branch.requires_grad_(False) + self.model.model.text_branch.eval() + + finally: + logging.disable(previous_level) + + del self.model.model.audio_branch + + gc.collect() + torch.cuda.empty_cache() + + def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"): + prompt_tokens = self.model.tokenizer(prompts) + attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True) + prompt_features = self.model.model.text_branch( + input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True), + attention_mask=attention_mask, + output_hidden_states=True + )["hidden_states"][layer_ix] + + return prompt_features, attention_mask + + def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any: + self.model.to(device) + + if self.use_text_features: + if len(texts) == 1: + text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device) + text_features = text_features[:1, ...] + text_attention_mask = text_attention_mask[:1, ...] + else: + text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device) + return [self.proj_out(text_features), text_attention_mask] + + # Fix for CLAP bug when only one text is passed + if len(texts) == 1: + text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...] + else: + text_embedding = self.model.get_text_embedding(texts, use_tensor=True) + + text_embedding = text_embedding.unsqueeze(1).to(device) + + return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)] + +class CLAPAudioConditioner(Conditioner): + def __init__(self, + output_dim: int, + clap_ckpt_path, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False): + super().__init__(512, output_dim, project_out=project_out) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + import laion_clap + from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + + self.model = model + + state_dict = clap_load_state_dict(clap_ckpt_path) + self.model.model.load_state_dict(state_dict, strict=False) + + self.model.model.audio_branch.requires_grad_(False) + self.model.model.audio_branch.eval() + + finally: + logging.disable(previous_level) + + del self.model.model.text_branch + + gc.collect() + torch.cuda.empty_cache() + + def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any: + + self.model.to(device) + + if isinstance(audios, list) or isinstance(audios, tuple): + audios = torch.cat(audios, dim=0) + + # Convert to mono + mono_audios = audios.mean(dim=1) + + with torch.cuda.amp.autocast(enabled=False): + audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True) + + audio_embedding = audio_embedding.unsqueeze(1).to(device) + + return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)] + +class T5Conditioner(Conditioner): + + T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", + "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", + "google/flan-t5-xl", "google/flan-t5-xxl", "t5-v1_1-xl", "google/t5-v1_1-xxl"] + + T5_MODEL_DIMS = { + "t5-small": 512, + "t5-base": 768, + "t5-large": 1024, + "t5-3b": 1024, + "t5-11b": 1024, + "t5-v1_1-xl": 2048, + "google/t5-v1_1-xxl": 4096, + "google/flan-t5-small": 512, + "google/flan-t5-base": 768, + "google/flan-t5-large": 1024, + "google/flan-t5-3b": 1024, + "google/flan-t5-11b": 1024, + "google/flan-t5-xl": 2048, + "google/flan-t5-xxl": 4096, + } + + def __init__( + self, + output_dim: int, + t5_model_name: str = "t5-base", + max_length: str = 77, + enable_grad: bool = False, + project_out: bool = False + ): + assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}" + super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out) + + from transformers import T5EncoderModel, AutoTokenizer + + self.max_length = max_length + self.enable_grad = enable_grad + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length) + # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad) + self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('useful_ckpts', t5_model_name)) + model = T5EncoderModel.from_pretrained(os.path.join('useful_ckpts', t5_model_name)).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + finally: + logging.disable(previous_level) + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): + embeddings = self.model( + input_ids=input_ids, attention_mask=attention_mask + )["last_hidden_state"] + + embeddings = self.proj_out(embeddings.float()) + + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, attention_mask + +def patch_clip(clip_model): + # a hack to make it output last hidden states + # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 + def new_encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = self.transformer(x, attn_mask=self.attn_mask) + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + return F.normalize(x, dim=-1) if normalize else x + + clip_model.encode_text = new_encode_text.__get__(clip_model) + return clip_model + +class CLIPTextConditioner(Conditioner): + def __init__( + self, + output_dim: int, + max_length: str = 77, + enable_grad: bool = False, + project_out: bool = False + ): + super().__init__(1024, output_dim, project_out=project_out) + + from transformers import T5EncoderModel, AutoTokenizer + import open_clip + from open_clip import create_model_from_pretrained + + self.max_length = max_length + self.enable_grad = enable_grad + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',cache_dir='useful_ckpts/DFN5B-CLIP-ViT-H-14-384', + return_transform=False).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + model = patch_clip(model) + self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14' + finally: + logging.disable(previous_level) + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + + encoded = self.tokenizer( + texts + ).to(device) + + # input_ids = encoded["input_ids"].to(device) + # attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): + embeddings = self.model.encode_text( + encoded + ) + + embeddings = self.proj_out(embeddings.float()) + + # embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, torch.ones(embeddings.shape[0], 1).to(device) + +def patch_clip(clip_model): + # a hack to make it output last hidden states + # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 + def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = text_outputs[0] + # pooled_output = text_outputs[1] + # text_features = self.text_projection(pooled_output) + + return last_hidden_state + + clip_model.get_text_features = new_get_text_features.__get__(clip_model) + return clip_model + +class MetaCLIPTextConditioner(Conditioner): + def __init__( + self, + output_dim: int, + max_length: str = 77, + enable_grad: bool = False, + project_out: bool = False + ): + super().__init__(1024, output_dim, project_out=project_out) + + from transformers import AutoModel + from transformers import AutoProcessor + + self.max_length = max_length + self.enable_grad = enable_grad + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.model = AutoModel.from_pretrained("useful_ckpts/metaclip-huge") + self.model = patch_clip(self.model) + self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge") + finally: + logging.disable(previous_level) + + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + encoded = self.clip_processor(text=texts, return_tensors="pt", padding=True).to(device) + + # input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.set_grad_enabled(self.enable_grad): + embeddings = self.model.get_text_features( + **encoded + ) + + embeddings = self.proj_out(embeddings.float()) + + # embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, torch.ones(embeddings.shape[0],1).to(device) + +class PhonemeConditioner(Conditioner): + """ + A conditioner that turns text into phonemes and embeds them using a lookup table + Only works for English text + + Args: + output_dim: the dimension of the output embeddings + max_length: the maximum number of phonemes to embed + project_out: whether to add another linear projection to the output embeddings + """ + + def __init__( + self, + output_dim: int, + max_length: int = 1024, + project_out: bool = False, + ): + super().__init__(output_dim, output_dim, project_out=project_out) + + from g2p_en import G2p + + self.max_length = max_length + + self.g2p = G2p() + + # Reserving 0 for padding, 1 for ignored + self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim) + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.phoneme_embedder.to(device) + self.proj_out.to(device) + + batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length] + + phoneme_ignore = [" ", *string.punctuation] + + # Remove ignored phonemes and cut to max length + batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes] + + # Convert to ids + phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes] + + #Pad to match longest and make a mask tensor for the padding + longest = max([len(ids) for ids in phoneme_ids]) + phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids] + + phoneme_ids = torch.tensor(phoneme_ids).to(device) + + # Convert to embeddings + phoneme_embeds = self.phoneme_embedder(phoneme_ids) + + phoneme_embeds = self.proj_out(phoneme_embeds) + + return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device) + +class TokenizerLUTConditioner(Conditioner): + """ + A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary + + Args: + tokenizer_name: the name of the tokenizer from the Hugging Face transformers library + output_dim: the dimension of the output embeddings + max_length: the maximum length of the text to embed + project_out: whether to add another linear projection to the output embeddings + """ + + def __init__( + self, + tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library + output_dim: int, + max_length: int = 1024, + project_out: bool = False, + ): + super().__init__(output_dim, output_dim, project_out=project_out) + + from transformers import AutoTokenizer + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + finally: + logging.disable(previous_level) + + self.max_length = max_length + + self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim) + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + self.proj_out.to(device) + + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + embeddings = self.token_embedder(input_ids) + + embeddings = self.proj_out(embeddings) + + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, attention_mask + +class PretransformConditioner(Conditioner): + """ + A conditioner that uses a pretransform's encoder for conditioning + + Args: + pretransform: an instantiated pretransform to use for conditioning + output_dim: the dimension of the output embeddings + """ + def __init__(self, pretransform: Pretransform, output_dim: int): + super().__init__(pretransform.encoded_channels, output_dim) + + self.pretransform = pretransform + + def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.pretransform.to(device) + self.proj_out.to(device) + + if isinstance(audio, list) or isinstance(audio, tuple): + audio = torch.cat(audio, dim=0) + + # Convert audio to pretransform input channels + audio = set_audio_channels(audio, self.pretransform.io_channels) + + latents = self.pretransform.encode(audio) + + latents = self.proj_out(latents) + + return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)] + +class MultiConditioner(nn.Module): + """ + A module that applies multiple conditioners to an input dictionary based on the keys + + Args: + conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt") + default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"}) + """ + def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}): + super().__init__() + + self.conditioners = nn.ModuleDict(conditioners) + self.default_keys = default_keys + + def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]: + output = {} + + for key, conditioner in self.conditioners.items(): + condition_key = key + + conditioner_inputs = [] + + for x in batch_metadata: + + if condition_key not in x: + if condition_key in self.default_keys: + condition_key = self.default_keys[condition_key] + else: + raise ValueError(f"Conditioner key {condition_key} not found in batch metadata") + + #Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list + if isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple) and len(x[condition_key]) == 1: + conditioner_input = x[condition_key][0] + + else: + conditioner_input = x[condition_key] + + conditioner_inputs.append(conditioner_input) + + cond_output = conditioner(conditioner_inputs, device) + if len(cond_output) == 1: + output[key] = cond_output[0] + elif len(cond_output) == 2: + output[key] = cond_output + elif len(cond_output) == 4: + output[key] = cond_output[:2] + output[f'{key}_g'] = cond_output[2:] + + return output + +def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner: + """ + Create a MultiConditioner from a conditioning config dictionary + + Args: + config: the conditioning config dictionary + device: the device to put the conditioners on + """ + conditioners = {} + cond_dim = config["cond_dim"] + + default_keys = config.get("default_keys", {}) + + for conditioner_info in config["configs"]: + id = conditioner_info["id"] + + conditioner_type = conditioner_info["type"] + + conditioner_config = {"output_dim": cond_dim} + + conditioner_config.update(conditioner_info["config"]) + if conditioner_type == "t5": + conditioners[id] = T5Conditioner(**conditioner_config) + elif conditioner_type == "clap_text": + conditioners[id] = CLAPTextConditioner(**conditioner_config) + elif conditioner_type == "clip_text": + conditioners[id] = CLIPTextConditioner(**conditioner_config) + elif conditioner_type == "metaclip_text": + conditioners[id] = MetaCLIPTextConditioner(**conditioner_config) + elif conditioner_type == "clap_audio": + conditioners[id] = CLAPAudioConditioner(**conditioner_config) + elif conditioner_type == "cond_mlp": + conditioners[id] = Cond_MLP(**conditioner_config) + elif conditioner_type == "global_mlp": + conditioners[id] = Global_MLP(**conditioner_config) + elif conditioner_type == "sync_mlp": + conditioners[id] = Sync_MLP(**conditioner_config) + elif conditioner_type == "cond_mlp_1": + conditioners[id] = Cond_MLP_1(**conditioner_config) + elif conditioner_type == "cond_convmlp": + conditioners[id] = Cond_ConvMLP(**conditioner_config) + elif conditioner_type == "cond_mlp_global": + conditioners[id] = Cond_MLP_Global(**conditioner_config) + elif conditioner_type == "cond_mlp_global_1": + conditioners[id] = Cond_MLP_Global_1(**conditioner_config) + elif conditioner_type == "cond_mlp_global_2": + conditioners[id] = Cond_MLP_Global_2(**conditioner_config) + elif conditioner_type == "video_linear": + conditioners[id] = Video_Linear(**conditioner_config) + elif conditioner_type == "video_global": + conditioners[id] = Video_Global(**conditioner_config) + elif conditioner_type == "video_sync": + conditioners[id] = Video_Sync(**conditioner_config) + elif conditioner_type == "text_linear": + conditioners[id] = Text_Linear(**conditioner_config) + elif conditioner_type == "video_clip": + conditioners[id] = CLIPConditioner(**conditioner_config) + elif conditioner_type == "video_hiera": + conditioners[id] = VideoHieraConditioner(**conditioner_config) + elif conditioner_type == "meta_query": + try: + from .meta_queries.model import MLLMInContext + except ImportError: + raise ImportError("meta_queries module is not available. Cannot create meta_query conditioner.") + conditioners[id] = MLLMInContext(**conditioner_config) + elif conditioner_type == "int": + conditioners[id] = IntConditioner(**conditioner_config) + elif conditioner_type == "number": + conditioners[id] = NumberConditioner(**conditioner_config) + elif conditioner_type == "phoneme": + conditioners[id] = PhonemeConditioner(**conditioner_config) + elif conditioner_type == "lut": + conditioners[id] = TokenizerLUTConditioner(**conditioner_config) + elif conditioner_type == "pretransform": + sample_rate = conditioner_config.pop("sample_rate", None) + assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners" + + from prismaudio_core.factory import create_pretransform_from_config + pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate) + + if conditioner_config.get("pretransform_ckpt_path", None) is not None: + pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path"))) + + conditioners[id] = PretransformConditioner(pretransform, **conditioner_config) + elif conditioner_type == "mm_unchang": + conditioners[id] = mm_unchang(**conditioner_config) + else: + raise ValueError(f"Unknown conditioner type: {conditioner_type}") + + return MultiConditioner(conditioners, default_keys=default_keys) \ No newline at end of file diff --git a/prismaudio_core/models/diffusion.py b/prismaudio_core/models/diffusion.py new file mode 100644 index 0000000..8233c09 --- /dev/null +++ b/prismaudio_core/models/diffusion.py @@ -0,0 +1,965 @@ +import torch +from torch import nn +from torch.nn import functional as F +from functools import partial +import numpy as np +import typing as tp + +from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes +from .conditioners import MultiConditioner +from .dit import DiffusionTransformer +from .pretransforms import Pretransform + +from .adp import UNetCFG1d, UNet1d + +# Lazy imports for factory functions to avoid circular imports +def _get_create_pretransform_from_config(): + from prismaudio_core.factory import create_pretransform_from_config + return create_pretransform_from_config + +def _get_create_multi_conditioner_from_conditioning_config(): + from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config + return create_multi_conditioner_from_conditioning_config + +from time import time + +class Profiler: + + def __init__(self): + self.ticks = [[time(), None]] + + def tick(self, msg): + self.ticks.append([time(), msg]) + + def __repr__(self): + rep = 80 * "=" + "\n" + for i in range(1, len(self.ticks)): + msg = self.ticks[i][1] + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + rep += msg + f": {ellapsed*1000:.2f}ms\n" + rep += 80 * "=" + "\n\n\n" + return rep + +class DiffusionModel(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x, t, **kwargs): + raise NotImplementedError() + +class DiffusionModelWrapper(nn.Module): + def __init__( + self, + model: DiffusionModel, + io_channels, + sample_size, + sample_rate, + min_input_length, + pretransform: tp.Optional[Pretransform] = None, + ): + super().__init__() + self.io_channels = io_channels + self.sample_size = sample_size + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.model = model + + if pretransform is not None: + self.pretransform = pretransform + else: + self.pretransform = None + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +class ConditionedDiffusionModel(nn.Module): + def __init__(self, + *args, + supports_cross_attention: bool = False, + supports_input_concat: bool = False, + supports_global_cond: bool = False, + supports_prepend_cond: bool = False, + **kwargs): + super().__init__(*args, **kwargs) + self.supports_cross_attention = supports_cross_attention + self.supports_input_concat = supports_input_concat + self.supports_global_cond = supports_global_cond + self.supports_prepend_cond = supports_prepend_cond + + def forward(self, + x: torch.Tensor, + t: torch.Tensor, + cross_attn_cond: torch.Tensor = None, + cross_attn_mask: torch.Tensor = None, + input_concat_cond: torch.Tensor = None, + global_embed: torch.Tensor = None, + prepend_cond: torch.Tensor = None, + prepend_cond_mask: torch.Tensor = None, + cfg_scale: float = 1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + **kwargs): + raise NotImplementedError() + +class ConditionedDiffusionModelWrapper(nn.Module): + """ + A diffusion model that takes in conditioning + """ + def __init__( + self, + model: ConditionedDiffusionModel, + conditioner: MultiConditioner, + io_channels, + sample_rate, + min_input_length: int, + diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", + zero_init: bool = False, + pretransform: tp.Optional[Pretransform] = None, + cross_attn_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [], + input_concat_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + add_cond_ids: tp.List[str] = [], + sync_cond_ids: tp.List[str] = [], + ): + super().__init__() + + self.model = model + self.conditioner = conditioner + self.io_channels = io_channels + self.sample_rate = sample_rate + self.diffusion_objective = diffusion_objective + self.pretransform = pretransform + self.cross_attn_cond_ids = cross_attn_cond_ids + self.global_cond_ids = global_cond_ids + self.input_concat_ids = input_concat_ids + self.prepend_cond_ids = prepend_cond_ids + self.add_cond_ids = add_cond_ids + self.sync_cond_ids = sync_cond_ids + self.min_input_length = min_input_length + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + if zero_init is True: + self.conditioner.apply(_basic_init) + self.model.model.initialize_weights() + + + def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False): + cross_attention_input = None + cross_attention_masks = None + global_cond = None + input_concat_cond = None + prepend_cond = None + prepend_cond_mask = None + add_input = None + sync_input = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = [] + cross_attention_masks = [] + + for key in self.cross_attn_cond_ids: + cross_attn_in, cross_attn_mask = conditioning_tensors[key] + + # Add sequence dimension if it's not there + if len(cross_attn_in.shape) == 2: + cross_attn_in = cross_attn_in.unsqueeze(1) + # cross_attn_mask = cross_attn_mask.unsqueeze(1) + + cross_attention_input.append(cross_attn_in) + cross_attention_masks.append(cross_attn_mask) + # import ipdb + # ipdb.set_trace() + cross_attention_input = torch.cat(cross_attention_input, dim=1) + cross_attention_masks = torch.cat(cross_attention_masks, dim=1) + + if len(self.add_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + add_input = [] + + for key in self.add_cond_ids: + add_in = conditioning_tensors[key][0] + + # Add sequence dimension if it's not there + if len(add_in.shape) == 2: + add_in = add_in.unsqueeze(1) + # add_in = add_in.transpose(1,2) + # add_in = F.interpolate(add_in, (194, ), mode='linear', align_corners=False) + # add_in = add_in.transpose(1,2) + add_input.append(add_in) + + add_input = torch.cat(add_input, dim=2) + + if len(self.sync_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + sync_input = [] + + for key in self.sync_cond_ids: + sync_in = conditioning_tensors[key][0] + + # Add sequence dimension if it's not there + if len(sync_in.shape) == 2: + sync_in = sync_in.unsqueeze(1) + sync_input.append(sync_in) + + sync_input = torch.cat(sync_input, dim=2) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_conds = [] + for key in self.global_cond_ids: + global_cond_input = conditioning_tensors[key][0] + if len(global_cond_input.shape) == 2: + global_cond_input = global_cond_input.unsqueeze(1) + global_conds.append(global_cond_input) + + # # Concatenate over the channel dimension + # if global_conds[0].shape[-1] == 768: + # global_cond = torch.cat(global_conds, dim=-1) + # else: + # global_cond = sum(global_conds) + global_cond = sum(global_conds) + # global_cond = torch.cat(global_conds, dim=-1) + + if len(global_cond.shape) == 3: + global_cond = global_cond.squeeze(1) + + if len(self.input_concat_ids) > 0: + # Concatenate all input concat conditioning inputs over the channel dimension + # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq) + input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) + prepend_conds = [] + prepend_cond_masks = [] + + for key in self.prepend_cond_ids: + prepend_cond_input, prepend_cond_mask = conditioning_tensors[key] + if len(prepend_cond_input.shape) == 2: + prepend_cond_input = prepend_cond_input.unsqueeze(1) + prepend_conds.append(prepend_cond_input) + prepend_cond_masks.append(prepend_cond_mask) + + prepend_cond = torch.cat(prepend_conds, dim=1) + prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1) + + if negative: + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_cross_attn_mask": cross_attention_masks, + "negative_global_cond": global_cond, + "negative_input_concat_cond": input_concat_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "cross_attn_mask": cross_attention_masks, + "global_cond": global_cond, + "input_concat_cond": input_concat_cond, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask, + "add_cond": add_input, + "sync_cond": sync_input + } + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): + return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs) + + def generate(self, *args, **kwargs): + from prismaudio_core.inference.generation import generate_diffusion_cond + return generate_diffusion_cond(self, *args, **kwargs) + +class UNetCFG1DWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True) + + self.model = UNetCFG1d(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + input_concat_cond=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + **kwargs): + p = Profiler() + + p.tick("start") + + channels_list = None + if input_concat_cond is not None: + channels_list = [input_concat_cond] + + outputs = self.model( + x, + t, + embedding=cross_attn_cond, + embedding_mask=cross_attn_mask, + features=global_cond, + channels_list=channels_list, + embedding_scale=cfg_scale, + embedding_mask_proba=cfg_dropout_prob, + batch_cfg=batch_cfg, + rescale_cfg=rescale_cfg, + negative_embedding=negative_cross_attn_cond, + negative_embedding_mask=negative_cross_attn_mask, + **kwargs) + + p.tick("UNetCFG1D forward") + + #print(f"Profiler: {p}") + return outputs + +class UNet1DCondWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True) + + self.model = UNet1d(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + input_concat_cond=None, + global_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + **kwargs): + + channels_list = None + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + + channels_list = [input_concat_cond] + + outputs = self.model( + x, + t, + features=global_cond, + channels_list=channels_list, + **kwargs) + + return outputs + +class UNet1DUncondWrapper(DiffusionModel): + def __init__( + self, + in_channels, + *args, + **kwargs + ): + super().__init__() + + self.model = UNet1d(in_channels=in_channels, *args, **kwargs) + + self.io_channels = in_channels + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +class DAU1DCondWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True) + + self.model = DiffusionAttnUnet1D(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + input_concat_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + **kwargs): + + return self.model(x, t, cond = input_concat_cond) + +class DiffusionAttnUnet1D(nn.Module): + def __init__( + self, + io_channels = 2, + depth=14, + n_attn_layers = 6, + channels = [128, 128, 256, 256] + [512] * 10, + cond_dim = 0, + cond_noise_aug = False, + kernel_size = 5, + learned_resample = False, + strides = [2] * 13, + conv_bias = True, + use_snake = False + ): + super().__init__() + + self.cond_noise_aug = cond_noise_aug + + self.io_channels = io_channels + + if self.cond_noise_aug: + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_embed = FourierFeatures(1, 16) + + attn_layer = depth - n_attn_layers + + strides = [1] + strides + + block = nn.Identity() + + conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake) + + for i in range(depth, 0, -1): + c = channels[i - 1] + stride = strides[i-1] + if stride > 2 and not learned_resample: + raise ValueError("Must have stride 2 without learned resampling") + + if i > 1: + c_prev = channels[i - 2] + add_attn = i >= attn_layer and n_attn_layers > 0 + block = SkipBlock( + Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"), + conv_block(c_prev, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + block, + conv_block(c * 2 if i != depth else c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c_prev), + SelfAttention1d(c_prev, c_prev // + 32) if add_attn else nn.Identity(), + Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic") + ) + else: + cond_embed_dim = 16 if not self.cond_noise_aug else 32 + block = nn.Sequential( + conv_block((io_channels + cond_dim) + cond_embed_dim, c, c), + conv_block(c, c, c), + conv_block(c, c, c), + block, + conv_block(c * 2, c, c), + conv_block(c, c, c), + conv_block(c, c, io_channels, is_last=True), + ) + self.net = block + + with torch.no_grad(): + for param in self.net.parameters(): + param *= 0.5 + + def forward(self, x, t, cond=None, cond_aug_scale=None): + + timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape) + + inputs = [x, timestep_embed] + + if cond is not None: + if cond.shape[2] != x.shape[2]: + cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False) + + if self.cond_noise_aug: + # Get a random number between 0 and 1, uniformly sampled + if cond_aug_scale is None: + aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond) + else: + aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond) + + # Add noise to the conditioning signal + cond = cond + torch.randn_like(cond) * aug_level[:, None, None] + + # Get embedding for noise cond level, reusing timestamp_embed + aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape) + + inputs.append(aug_level_embed) + + inputs.append(cond) + + outputs = self.net(torch.cat(inputs, dim=1)) + + return outputs + +class DiTWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) + + self.model = DiffusionTransformer(*args, **kwargs) + # with torch.no_grad(): + # for param in self.model.parameters(): + # param *= 0.5 + + def forward(self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + negative_input_concat_cond=None, + global_cond=None, + negative_global_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = True, + rescale_cfg: bool = False, + scale_phi: float = 0.0, + **kwargs): + + assert batch_cfg, "batch_cfg must be True for DiTWrapper" + #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" + + return self.model( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_mask, + negative_cross_attn_cond=negative_cross_attn_cond, + negative_cross_attn_mask=negative_cross_attn_mask, + input_concat_cond=input_concat_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + cfg_scale=cfg_scale, + cfg_dropout_prob=cfg_dropout_prob, + scale_phi=scale_phi, + global_embed=global_cond, + **kwargs) + +class MMDiTWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) + + self.model = MMAudio(*args, **kwargs) + + # with torch.no_grad(): + # for param in self.model.parameters(): + # param *= 0.5 + + def forward(self, + x, + t, + clip_f, + sync_f, + text_f, + inpaint_masked_input=None, + t5_features=None, + metaclip_global_text_features=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = True, + rescale_cfg: bool = False, + scale_phi: float = 0.0, + **kwargs): + + # breakpoint() + assert batch_cfg, "batch_cfg must be True for DiTWrapper" + #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" + + return self.model( + latent=x, + t=t, + clip_f=clip_f, + sync_f=sync_f, + text_f=text_f, + inpaint_masked_input=inpaint_masked_input, + t5_features=t5_features, + metaclip_global_text_features=metaclip_global_text_features, + cfg_scale=cfg_scale, + cfg_dropout_prob=cfg_dropout_prob, + scale_phi=scale_phi, + **kwargs) + +class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel): + """ + A diffusion model that takes in conditioning + """ + def __init__( + self, + model, + conditioner: MultiConditioner, + io_channels, + sample_rate, + min_input_length: int, + diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", + pretransform: tp.Optional[Pretransform] = None, + cross_attn_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [], + input_concat_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + add_cond_ids: tp.List[str] = [], + mm_cond_ids: tp.List[str] = [], + ): + super().__init__() + + self.model = model + self.conditioner = conditioner + self.io_channels = io_channels + self.sample_rate = sample_rate + self.diffusion_objective = diffusion_objective + self.pretransform = pretransform + self.cross_attn_cond_ids = cross_attn_cond_ids + self.global_cond_ids = global_cond_ids + self.input_concat_ids = input_concat_ids + self.prepend_cond_ids = prepend_cond_ids + self.add_cond_ids = add_cond_ids + self.min_input_length = min_input_length + self.mm_cond_ids = mm_cond_ids + + assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper" + assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper" + assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper" + assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper" + assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper" + assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper" + assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper" + assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper" + assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper" + # assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper" + + def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False): + assert negative == False, "negative conditioning is not supported for MMDiTWrapper" + cross_attention_input = None + cross_attention_masks = None + global_cond = None + input_concat_cond = None + prepend_cond = None + prepend_cond_mask = None + add_input = None + inpaint_masked_input = None + t5_features = None + metaclip_global_text_features = None + clip_f = conditioning_tensors["metaclip_features"] + sync_f = conditioning_tensors["sync_features"] + text_f = conditioning_tensors["metaclip_text_features"] + if 'inpaint_masked_input' in conditioning_tensors.keys(): + inpaint_masked_input = conditioning_tensors["inpaint_masked_input"] + if 't5_features' in conditioning_tensors.keys(): + t5_features = conditioning_tensors["t5_features"] + if 'metaclip_global_text_features' in conditioning_tensors.keys(): + metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"] + return { + "clip_f": clip_f, + "sync_f": sync_f, + "text_f": text_f, + "inpaint_masked_input": inpaint_masked_input, + "t5_features": t5_features, + "metaclip_global_text_features": metaclip_global_text_features + } + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): + # breakpoint() + # print(kwargs) + return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs) + + def generate(self, *args, **kwargs): + from prismaudio_core.inference.generation import generate_diffusion_cond + return generate_diffusion_cond(self, *args, **kwargs) + +class DiTUncondWrapper(DiffusionModel): + def __init__( + self, + io_channels, + *args, + **kwargs + ): + super().__init__() + + self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs) + + self.io_channels = io_channels + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]): + diffusion_uncond_config = config["model"] + + model_type = diffusion_uncond_config.get('type', None) + + diffusion_config = diffusion_uncond_config.get('config', {}) + + assert model_type is not None, "Must specify model type in config" + + pretransform = diffusion_uncond_config.get("pretransform", None) + + sample_size = config.get("sample_size", None) + assert sample_size is not None, "Must specify sample size in config" + + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "Must specify sample rate in config" + + if pretransform is not None: + pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if model_type == 'DAU1d': + + model = DiffusionAttnUnet1D( + **diffusion_config + ) + + elif model_type == "adp_uncond_1d": + + model = UNet1DUncondWrapper( + **diffusion_config + ) + + elif model_type == "dit": + model = DiTUncondWrapper( + **diffusion_config + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + + return DiffusionModelWrapper(model, + io_channels=model.io_channels, + sample_size=sample_size, + sample_rate=sample_rate, + pretransform=pretransform, + min_input_length=min_input_length) + +def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]): + diffusion_uncond_config = config["model"] + + + diffusion_config = diffusion_uncond_config.get('diffusion', {}) + model_type = diffusion_config.get('type', None) + model_config = diffusion_config.get("config",{}) + assert model_type is not None, "Must specify model type in config" + + pretransform = diffusion_uncond_config.get("pretransform", None) + + sample_size = config.get("sample_size", None) + assert sample_size is not None, "Must specify sample size in config" + + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "Must specify sample rate in config" + + if pretransform is not None: + pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if model_type == 'DAU1d': + + model = DiffusionAttnUnet1D( + **model_config + ) + + elif model_type == "adp_uncond_1d": + + model = UNet1DUncondWrapper( + io_channels = io_channels, + **model_config + ) + elif model_type == "dit": + model = DiTUncondWrapper( + **model_config + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + + return DiffusionModelWrapper(model, + io_channels=model.io_channels, + sample_size=sample_size, + sample_rate=sample_rate, + pretransform=pretransform, + min_input_length=min_input_length) + +def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): + + model_config = config["model"] + + model_type = config["model_type"] + + diffusion_config = model_config.get('diffusion', None) + assert diffusion_config is not None, "Must specify diffusion config" + + diffusion_model_type = diffusion_config.get('type', None) + assert diffusion_model_type is not None, "Must specify diffusion model type" + + diffusion_model_config = diffusion_config.get('config', None) + assert diffusion_model_config is not None, "Must specify diffusion model config" + + if diffusion_model_type == 'adp_cfg_1d': + diffusion_model = UNetCFG1DWrapper(**diffusion_model_config) + elif diffusion_model_type == 'adp_1d': + diffusion_model = UNet1DCondWrapper(**diffusion_model_config) + elif diffusion_model_type == 'dit': + diffusion_model = DiTWrapper(**diffusion_model_config) + elif diffusion_model_type == 'mmdit': + diffusion_model = MMDiTWrapper(**diffusion_model_config) + + io_channels = model_config.get('io_channels', None) + assert io_channels is not None, "Must specify io_channels in model config" + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + diffusion_objective = diffusion_config.get('diffusion_objective', 'v') + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is not None: + conditioner = _get_create_multi_conditioner_from_conditioning_config()(conditioning_config) + + cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) + add_cond_ids = diffusion_config.get('add_cond_ids', []) + sync_cond_ids = diffusion_config.get('sync_cond_ids', []) + global_cond_ids = diffusion_config.get('global_cond_ids', []) + input_concat_ids = diffusion_config.get('input_concat_ids', []) + prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) + mm_cond_ids = diffusion_config.get('mm_cond_ids', []) + zero_init = diffusion_config.get('zero_init', False) + pretransform = model_config.get("pretransform", None) + + if pretransform is not None: + pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d": + min_input_length *= np.prod(diffusion_model_config["factors"]) + elif diffusion_model_type == "dit": + min_input_length *= diffusion_model.model.patch_size + + # Get the proper wrapper class + + extra_kwargs = {} + + if model_type == "mm_diffusion_cond": + wrapper_fn = MMConditionedDiffusionModelWrapper + extra_kwargs["diffusion_objective"] = diffusion_objective + extra_kwargs["mm_cond_ids"] = mm_cond_ids + + if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill': + wrapper_fn = ConditionedDiffusionModelWrapper + extra_kwargs["diffusion_objective"] = diffusion_objective + + elif model_type == "diffusion_prior": + prior_type = model_config.get("prior_type", None) + assert prior_type is not None, "Must specify prior_type in diffusion prior model config" + + if prior_type == "mono_stereo": + from .diffusion_prior import MonoToStereoDiffusionPrior + wrapper_fn = MonoToStereoDiffusionPrior + + return wrapper_fn( + diffusion_model, + conditioner, + min_input_length=min_input_length, + sample_rate=sample_rate, + cross_attn_cond_ids=cross_attention_ids, + global_cond_ids=global_cond_ids, + input_concat_ids=input_concat_ids, + prepend_cond_ids=prepend_cond_ids, + add_cond_ids=add_cond_ids, + sync_cond_ids=sync_cond_ids, + pretransform=pretransform, + io_channels=io_channels, + zero_init=zero_init, + **extra_kwargs + ) \ No newline at end of file diff --git a/prismaudio_core/models/dit.py b/prismaudio_core/models/dit.py new file mode 100644 index 0000000..5e614d7 --- /dev/null +++ b/prismaudio_core/models/dit.py @@ -0,0 +1,541 @@ +import typing as tp +import math +import torch +# from beartype.typing import Tuple +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP +from .blocks import FourierFeatures +from .transformer import ContinuousTransformer +from .utils import mask_from_frac_lengths, resample + +class DiffusionTransformer(nn.Module): + def __init__(self, + io_channels=32, + patch_size=1, + embed_dim=768, + cond_token_dim=0, + project_cond_tokens=True, + global_cond_dim=0, + project_global_cond=True, + input_concat_dim=0, + prepend_cond_dim=0, + cond_ctx_dim=0, + depth=12, + num_heads=8, + transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer", + global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + timestep_cond_type: tp.Literal["global", "input_concat"] = "global", + add_token_dim=0, + sync_token_dim=0, + use_mlp=False, + use_zero_init=False, + **kwargs): + + super().__init__() + + self.cond_token_dim = cond_token_dim + + # Timestep embeddings + timestep_features_dim = 256 + # Timestep embeddings + self.timestep_cond_type = timestep_cond_type + self.timestep_features = FourierFeatures(1, timestep_features_dim) + + if timestep_cond_type == "global": + timestep_embed_dim = embed_dim + elif timestep_cond_type == "input_concat": + assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat" + input_concat_dim += timestep_embed_dim + + self.to_timestep_embed = nn.Sequential( + nn.Linear(timestep_features_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + self.use_mlp = use_mlp + if cond_token_dim > 0: + # Conditioning tokens + cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim + self.to_cond_embed = nn.Sequential( + nn.Linear(cond_token_dim, cond_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) + ) + else: + cond_embed_dim = 0 + + if global_cond_dim > 0: + # Global conditioning + global_embed_dim = global_cond_dim if not project_global_cond else embed_dim + self.to_global_embed = nn.Sequential( + nn.Linear(global_cond_dim, global_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(global_embed_dim, global_embed_dim, bias=False) + ) + if add_token_dim > 0: + # Conditioning tokens + add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim + self.to_add_embed = nn.Sequential( + nn.Linear(add_token_dim, add_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(add_embed_dim, add_embed_dim, bias=False) + ) + else: + add_embed_dim = 0 + + if sync_token_dim > 0: + # Conditioning tokens + sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim + self.to_sync_embed = nn.Sequential( + nn.Linear(sync_token_dim, sync_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(sync_embed_dim, sync_embed_dim, bias=False) + ) + else: + sync_embed_dim = 0 + + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + self.input_concat_dim = input_concat_dim + + dim_in = io_channels + self.input_concat_dim + + self.patch_size = patch_size + + # Transformer + + self.transformer_type = transformer_type + + self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True) + self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True) + self.global_cond_type = global_cond_type + if self.transformer_type == "continuous_transformer": + + global_dim = None + + if self.global_cond_type == "adaLN": + # The global conditioning is projected to the embed_dim already at this point + global_dim = embed_dim + + self.transformer = ContinuousTransformer( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + cross_attend = cond_token_dim > 0, + cond_token_dim = cond_embed_dim, + global_cond_dim=global_dim, + **kwargs + ) + else: + raise ValueError(f"Unknown transformer type: {self.transformer_type}") + + self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False) + self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) + nn.init.zeros_(self.preprocess_conv.weight) + nn.init.zeros_(self.postprocess_conv.weight) + + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + # if isinstance(module, nn.Conv1d): + # if module.bias is not None: + # nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02) + nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02) + + # Zero-out output layers: + if self.global_cond_type == "adaLN": + for block in self.transformer.layers: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + nn.init.constant_(self.empty_clip_feat, 0) + nn.init.constant_(self.empty_sync_feat, 0) + + def _forward( + self, + x, + t, + mask=None, + cross_attn_cond=None, + cross_attn_cond_mask=None, + input_concat_cond=None, + global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + add_cond=None, + add_masks=None, + sync_cond=None, + return_info=False, + **kwargs): + + if cross_attn_cond is not None: + cross_attn_cond = self.to_cond_embed(cross_attn_cond) + + if global_embed is not None: + # Project the global conditioning to the embedding dimension + global_embed = self.to_global_embed(global_embed) + + prepend_inputs = None + prepend_mask = None + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + + prepend_inputs = prepend_cond + if prepend_cond_mask is not None: + prepend_mask = prepend_cond_mask + + if input_concat_cond is not None: + # reshape from (b, n, c) to (b, c, n) + if input_concat_cond.shape[1] != x.shape[1]: + input_concat_cond = input_concat_cond.transpose(1,2) + # Interpolate input_concat_cond to the same length as x + # if input_concat_cond.shape[1] != x.shape[2]: + # input_concat_cond = input_concat_cond.transpose(1,2) + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + # input_concat_cond = input_concat_cond.transpose(1,2) + # if len(global_embed.shape) == 2: + # global_embed = global_embed.unsqueeze(1) + # global_embed = global_embed + input_concat_cond + x = torch.cat([x, input_concat_cond], dim=1) + + # Get the batch of timestep embeddings + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + # import ipdb + # ipdb.set_trace() + # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists + if self.timestep_cond_type == "global": + if global_embed is not None: + if len(global_embed.shape) == 3: + timestep_embed = timestep_embed.unsqueeze(1) + global_embed = global_embed + timestep_embed + else: + global_embed = timestep_embed + elif self.timestep_cond_type == "input_concat": + x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1) + + # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer + if self.global_cond_type == "prepend" and global_embed is not None: + if prepend_inputs is None: + # Prepend inputs are just the global embed, and the mask is all ones + if len(global_embed.shape) == 2: + prepend_inputs = global_embed.unsqueeze(1) + else: + prepend_inputs = global_embed + prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) + else: + # Prepend inputs are the prepend conditioning + the global embed + if len(global_embed.shape) == 2: + prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) + else: + prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1) + prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) + + prepend_length = prepend_inputs.shape[1] + + x = self.preprocess_conv(x) + x + x = rearrange(x, "b c t -> b t c") + + extra_args = {} + + if self.global_cond_type == "adaLN": + extra_args["global_cond"] = global_embed + + if self.patch_size > 1: + b, seq_len, c = x.shape + + # 计算需要填充的数量 + pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size + + if pad_amount > 0: + # 在时间维度上进行填充 + x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0) + x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) + + if add_cond is not None: + # Interpolate add_cond to the same length as x + # if self.use_mlp: + add_cond = self.to_add_embed(add_cond) + if add_cond.shape[1] != x.shape[1]: + add_cond = add_cond.transpose(1,2) + add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False) + add_cond = add_cond.transpose(1,2) + # add_cond = resample(add_cond, x) + + if sync_cond is not None: + sync_cond = self.to_sync_embed(sync_cond) + + if self.transformer_type == "continuous_transformer": + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) + + if return_info: + output, info = output + + output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:] + + if self.patch_size > 1: + output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) + # 移除之前添加的填充 + if pad_amount > 0: + output = output[:, :, :seq_len] + + output = self.postprocess_conv(output) + output + + if return_info: + return output, info + + return output + + def forward( + self, + x, + t, + cross_attn_cond=None, + cross_attn_cond_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + global_embed=None, + negative_global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + add_cond=None, + sync_cond=None, + cfg_scale=1.0, + cfg_dropout_prob=0.0, + causal=False, + scale_phi=0.0, + mask=None, + return_info=False, + **kwargs): + + assert causal == False, "Causal mode is not supported for DiffusionTransformer" + bsz, a, b = x.shape + model_dtype = next(self.parameters()).dtype + x = x.to(model_dtype) + t = t.to(model_dtype) + + if cross_attn_cond is not None: + cross_attn_cond = cross_attn_cond.to(model_dtype) + + if negative_cross_attn_cond is not None: + negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype) + + if input_concat_cond is not None: + input_concat_cond = input_concat_cond.to(model_dtype) + + if global_embed is not None: + global_embed = global_embed.to(model_dtype) + + if negative_global_embed is not None: + negative_global_embed = negative_global_embed.to(model_dtype) + + if prepend_cond is not None: + prepend_cond = prepend_cond.to(model_dtype) + + if add_cond is not None: + add_cond = add_cond.to(model_dtype) + + if sync_cond is not None: + sync_cond = sync_cond.to(model_dtype) + + if cross_attn_cond_mask is not None: + cross_attn_cond_mask = cross_attn_cond_mask.bool() + + cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.bool() + + + # CFG dropout + if cfg_dropout_prob > 0.0 and cfg_scale == 1.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if add_cond is not None: + null_embed = torch.zeros_like(add_cond, device=add_cond.device) + dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool) + add_cond = torch.where(dropout_mask, null_embed, add_cond) + + if sync_cond is not None: + null_embed = torch.zeros_like(sync_cond, device=sync_cond.device) + dropout_mask = torch.bernoulli(torch.full((sync_cond.shape[0], 1, 1), cfg_dropout_prob, device=sync_cond.device)).to(torch.bool) + sync_cond = torch.where(dropout_mask, null_embed, sync_cond) + + if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None): + # Classifier-free guidance + # Concatenate conditioned and unconditioned inputs on the batch dimension + batch_inputs = torch.cat([x, x], dim=0) + batch_timestep = torch.cat([t, t], dim=0) + if global_embed is not None and global_embed.shape[0] == bsz: + batch_global_cond = torch.cat([global_embed, global_embed], dim=0) + elif global_embed is not None: + batch_global_cond = global_embed + else: + batch_global_cond = None + + if input_concat_cond is not None and input_concat_cond.shape[0] == bsz: + batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) + elif input_concat_cond is not None: + batch_input_concat_cond = input_concat_cond + else: + batch_input_concat_cond = None + + batch_cond = None + batch_cond_masks = None + + # Handle CFG for cross-attention conditioning + if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz: + + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning + if negative_cross_attn_cond is not None: + + # If there's a negative cross-attention mask, set the masked tokens to the null embed + if negative_cross_attn_mask is not None: + negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) + + negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed) + + batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) + + else: + batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if cross_attn_cond_mask is not None: + batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) + elif cross_attn_cond is not None: + batch_cond = cross_attn_cond + else: + batch_cond = None + + batch_prepend_cond = None + batch_prepend_cond_mask = None + + if prepend_cond is not None and prepend_cond.shape[0] == bsz: + + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + elif prepend_cond is not None: + batch_prepend_cond = prepend_cond + else: + batch_prepend_cond = None + + batch_add_cond = None + + # Handle CFG for cross-attention conditioning + if add_cond is not None and add_cond.shape[0] == bsz: + + null_embed = torch.zeros_like(add_cond, device=add_cond.device) + + + batch_add_cond = torch.cat([add_cond, null_embed], dim=0) + elif add_cond is not None: + batch_add_cond = add_cond + else: + batch_add_cond = None + + batch_sync_cond = None + + # Handle CFG for cross-attention conditioning + if sync_cond is not None and sync_cond.shape[0] == bsz: + + null_embed = torch.zeros_like(sync_cond, device=sync_cond.device) + + + batch_sync_cond = torch.cat([sync_cond, null_embed], dim=0) + elif sync_cond is not None: + batch_sync_cond = sync_cond + else: + batch_sync_cond = None + + if mask is not None: + batch_masks = torch.cat([mask, mask], dim=0) + else: + batch_masks = None + + batch_output = self._forward( + batch_inputs, + batch_timestep, + cross_attn_cond=batch_cond, + cross_attn_cond_mask=batch_cond_masks, + mask = batch_masks, + input_concat_cond=batch_input_concat_cond, + global_embed = batch_global_cond, + prepend_cond = batch_prepend_cond, + prepend_cond_mask = batch_prepend_cond_mask, + add_cond = batch_add_cond, + sync_cond = batch_sync_cond, + return_info = return_info, + **kwargs) + + if return_info: + batch_output, info = batch_output + + cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + + # CFG Rescale + if scale_phi != 0.0: + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output + else: + output = cfg_output + + if return_info: + return output, info + + return output + + else: + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + add_cond=add_cond, + sync_cond=sync_cond, + mask=mask, + return_info=return_info, + **kwargs + ) \ No newline at end of file diff --git a/prismaudio_core/models/local_attention.py b/prismaudio_core/models/local_attention.py new file mode 100644 index 0000000..893ce11 --- /dev/null +++ b/prismaudio_core/models/local_attention.py @@ -0,0 +1,278 @@ +import torch + +from einops import rearrange +from torch import nn + +from .blocks import AdaRMSNorm +from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py +class ContinuousLocalTransformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + dim_in = None, + dim_out = None, + causal = False, + local_attn_window_size = 64, + heads = 8, + ff_mult = 2, + cond_dim = 0, + cross_attn_cond_dim = 0, + **kwargs + ): + super().__init__() + + dim_head = dim//heads + + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity() + + self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity() + + self.local_attn_window_size = local_attn_window_size + + self.cond_dim = cond_dim + + self.cross_attn_cond_dim = cross_attn_cond_dim + + self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32)) + + for _ in range(depth): + + self.layers.append(nn.ModuleList([ + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + Attention( + dim=dim, + dim_heads=dim_head, + causal=causal, + zero_init_output=True, + natten_kernel_size=local_attn_window_size, + ), + Attention( + dim=dim, + dim_heads=dim_head, + dim_context = cross_attn_cond_dim, + zero_init_output=True + ) if self.cross_attn_cond_dim > 0 else nn.Identity(), + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + FeedForward(dim = dim, mult = ff_mult, no_bias=True) + ])) + + def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None): + + x = checkpoint(self.project_in, x) + + if prepend_cond is not None: + x = torch.cat([prepend_cond, x], dim=1) + + pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + + for attn_norm, attn, xattn, ff_norm, ff in self.layers: + + residual = x + if cond is not None: + x = checkpoint(attn_norm, x, cond) + else: + x = checkpoint(attn_norm, x) + + x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual + + if cross_attn_cond is not None: + x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x + + residual = x + + if cond is not None: + x = checkpoint(ff_norm, x, cond) + else: + x = checkpoint(ff_norm, x) + + x = checkpoint(ff, x) + residual + + return checkpoint(self.project_out, x) + +class TransformerDownsampleBlock1D(nn.Module): + def __init__( + self, + in_channels, + embed_dim = 768, + depth = 3, + heads = 12, + downsample_ratio = 2, + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + self.downsample_ratio = downsample_ratio + + self.transformer = ContinuousLocalTransformer( + dim=embed_dim, + depth=depth, + heads=heads, + local_attn_window_size=local_attn_window_size, + **kwargs + ) + + self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() + + self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False) + + + def forward(self, x): + + x = checkpoint(self.project_in, x) + + # Compute + x = self.transformer(x) + + # Trade sequence length for channels + x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio) + + # Project back to embed dim + x = checkpoint(self.project_down, x) + + return x + +class TransformerUpsampleBlock1D(nn.Module): + def __init__( + self, + in_channels, + embed_dim, + depth = 3, + heads = 12, + upsample_ratio = 2, + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + self.upsample_ratio = upsample_ratio + + self.transformer = ContinuousLocalTransformer( + dim=embed_dim, + depth=depth, + heads=heads, + local_attn_window_size = local_attn_window_size, + **kwargs + ) + + self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() + + self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False) + + def forward(self, x): + + # Project to embed dim + x = checkpoint(self.project_in, x) + + # Project to increase channel dim + x = checkpoint(self.project_up, x) + + # Trade channels for sequence length + x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio) + + # Compute + x = self.transformer(x) + + return x + + +class TransformerEncoder1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dims = [96, 192, 384, 768], + heads = [12, 12, 12, 12], + depths = [3, 3, 3, 3], + ratios = [2, 2, 2, 2], + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + layers = [] + + for layer in range(len(depths)): + prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] + + layers.append( + TransformerDownsampleBlock1D( + in_channels = prev_dim, + embed_dim = embed_dims[layer], + heads = heads[layer], + depth = depths[layer], + downsample_ratio = ratios[layer], + local_attn_window_size = local_attn_window_size, + **kwargs + ) + ) + + self.layers = nn.Sequential(*layers) + + self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) + self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) + + def forward(self, x): + x = rearrange(x, "b c n -> b n c") + x = checkpoint(self.project_in, x) + x = self.layers(x) + x = checkpoint(self.project_out, x) + x = rearrange(x, "b n c -> b c n") + + return x + + +class TransformerDecoder1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dims = [768, 384, 192, 96], + heads = [12, 12, 12, 12], + depths = [3, 3, 3, 3], + ratios = [2, 2, 2, 2], + local_attn_window_size = 64, + **kwargs + ): + + super().__init__() + + layers = [] + + for layer in range(len(depths)): + prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] + + layers.append( + TransformerUpsampleBlock1D( + in_channels = prev_dim, + embed_dim = embed_dims[layer], + heads = heads[layer], + depth = depths[layer], + upsample_ratio = ratios[layer], + local_attn_window_size = local_attn_window_size, + **kwargs + ) + ) + + self.layers = nn.Sequential(*layers) + + self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) + self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) + + def forward(self, x): + x = rearrange(x, "b c n -> b n c") + x = checkpoint(self.project_in, x) + x = self.layers(x) + x = checkpoint(self.project_out, x) + x = rearrange(x, "b n c -> b c n") + return x \ No newline at end of file diff --git a/prismaudio_core/models/mmmodules/__init__.py b/prismaudio_core/models/mmmodules/__init__.py new file mode 100644 index 0000000..520434f --- /dev/null +++ b/prismaudio_core/models/mmmodules/__init__.py @@ -0,0 +1 @@ +# mmmodules package diff --git a/prismaudio_core/models/mmmodules/model/__init__.py b/prismaudio_core/models/mmmodules/model/__init__.py new file mode 100644 index 0000000..a86d6e1 --- /dev/null +++ b/prismaudio_core/models/mmmodules/model/__init__.py @@ -0,0 +1 @@ +# mmmodules.model package diff --git a/prismaudio_core/models/mmmodules/model/low_level.py b/prismaudio_core/models/mmmodules/model/low_level.py new file mode 100644 index 0000000..c8326a8 --- /dev/null +++ b/prismaudio_core/models/mmmodules/model/low_level.py @@ -0,0 +1,95 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class ChannelLastConv1d(nn.Conv1d): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 1) + x = super().forward(x) + x = x.permute(0, 2, 1) + return x + + +# https://github.com/Stability-AI/sd3-ref +class MLP(nn.Module): + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class ConvMLP(nn.Module): + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + kernel_size: int = 3, + padding: int = 1, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ChannelLastConv1d(dim, + hidden_dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + self.w2 = ChannelLastConv1d(hidden_dim, + dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + self.w3 = ChannelLastConv1d(dim, + hidden_dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) diff --git a/prismaudio_core/models/pqmf.py b/prismaudio_core/models/pqmf.py new file mode 100644 index 0000000..007fdb5 --- /dev/null +++ b/prismaudio_core/models/pqmf.py @@ -0,0 +1,393 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from scipy.optimize import fmin +from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord + +class PQMF(nn.Module): + """ + Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction. + Uses polyphase representation which is computationally more efficient for real-time. + + Parameters: + - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB. + - num_bands (int): Number of desired frequency bands. It must be a power of 2. + """ + + def __init__(self, attenuation, num_bands): + super(PQMF, self).__init__() + + # Ensure num_bands is a power of 2 + is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) + assert is_power_of_2, "'num_bands' must be a power of 2." + + # Create the prototype filter + prototype_filter = design_prototype_filter(attenuation, num_bands) + filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands) + padded_filter_bank = pad_to_nearest_power_of_two(filter_bank) + + # Register filters and settings + self.register_buffer("filter_bank", padded_filter_bank) + self.register_buffer("prototype", prototype_filter) + self.num_bands = num_bands + + def forward(self, signal): + """Decompose the signal into multiple frequency bands.""" + # If signal is not a pytorch tensor of Batch x Channels x Length, convert it + signal = prepare_signal_dimensions(signal) + # The signal length must be a multiple of num_bands. Pad it with zeros. + signal = pad_signal(signal, self.num_bands) + # run it + signal = polyphase_analysis(signal, self.filter_bank) + return apply_alias_cancellation(signal) + + def inverse(self, bands): + """Reconstruct the original signal from the frequency bands.""" + bands = apply_alias_cancellation(bands) + return polyphase_synthesis(bands, self.filter_bank) + + +def prepare_signal_dimensions(signal): + """ + Rearrange signal into Batch x Channels x Length. + + Parameters + ---------- + signal : torch.Tensor or numpy.ndarray + The input signal. + + Returns + ------- + torch.Tensor + Preprocessed signal tensor. + """ + # Convert numpy to torch tensor + if isinstance(signal, np.ndarray): + signal = torch.from_numpy(signal) + + # Ensure tensor + if not isinstance(signal, torch.Tensor): + raise ValueError("Input should be either a numpy array or a PyTorch tensor.") + + # Modify dimension of signal to Batch x Channels x Length + if signal.dim() == 1: + # This is just a mono signal. Unsqueeze to 1 x 1 x Length + signal = signal.unsqueeze(0).unsqueeze(0) + elif signal.dim() == 2: + # This is a multi-channel signal (e.g. stereo) + # Rearrange so that larger dimension (Length) is last + if signal.shape[0] > signal.shape[1]: + signal = signal.T + # Unsqueeze to 1 x Channels x Length + signal = signal.unsqueeze(0) + return signal + +def pad_signal(signal, num_bands): + """ + Pads the signal to make its length divisible by the given number of bands. + + Parameters + ---------- + signal : torch.Tensor + The input signal tensor, where the last dimension represents the signal length. + + num_bands : int + The number of bands by which the signal length should be divisible. + + Returns + ------- + torch.Tensor + The padded signal tensor. If the original signal length was already divisible + by num_bands, returns the original signal unchanged. + """ + remainder = signal.shape[-1] % num_bands + if remainder > 0: + padding_size = num_bands - remainder + signal = nn.functional.pad(signal, (0, padding_size)) + return signal + +def generate_modulated_filter_bank(prototype_filter, num_bands): + """ + Generate a QMF bank of cosine modulated filters based on a given prototype filter. + + Parameters + ---------- + prototype_filter : torch.Tensor + The prototype filter used as the basis for modulation. + num_bands : int + The number of desired subbands or filters. + + Returns + ------- + torch.Tensor + A bank of cosine modulated filters. + """ + + # Initialize indices for modulation. + subband_indices = torch.arange(num_bands).reshape(-1, 1) + + # Calculate the length of the prototype filter. + filter_length = prototype_filter.shape[-1] + + # Generate symmetric time indices centered around zero. + time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1) + + # Calculate phase offsets to ensure orthogonality between subbands. + phase_offsets = (-1)**subband_indices * np.pi / 4 + + # Compute the cosine modulation function. + modulation = torch.cos( + (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets + ) + + # Apply modulation to the prototype filter. + modulated_filters = 2 * prototype_filter * modulation + + return modulated_filters + + +def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): + """ + Design a lowpass filter using the Kaiser window. + + Parameters + ---------- + angular_cutoff : float + The angular frequency cutoff of the filter. + attenuation : float + The desired stopband attenuation in decibels (dB). + filter_length : int, optional + Desired length of the filter. If not provided, it's computed based on the given specs. + + Returns + ------- + ndarray + The designed lowpass filter coefficients. + """ + + estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi) + + # Ensure the estimated length is odd. + estimated_length = 2 * (estimated_length // 2) + 1 + + if filter_length is None: + filter_length = estimated_length + + return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi) + + +def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length): + """ + Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427 + + Parameters + ---------- + angular_cutoff : float + Angular frequency cutoff of the filter. + attenuation : float + Desired stopband attenuation in dB. + num_bands : int + Number of bands for the multiband filter system. + filter_length : int, optional + Desired length of the filter. + + Returns + ------- + float + The computed objective (loss) value for the given filter specs. + """ + + filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length) + convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full") + + return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:])) + + +def design_prototype_filter(attenuation, num_bands, filter_length=None): + """ + Design the optimal prototype filter for a multiband system given the desired specs. + + Parameters + ---------- + attenuation : float + The desired stopband attenuation in dB. + num_bands : int + Number of bands for the multiband filter system. + filter_length : int, optional + Desired length of the filter. If not provided, it's computed based on the given specs. + + Returns + ------- + ndarray + The optimal prototype filter coefficients. + """ + + optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), + 1 / num_bands, disp=0)[0] + + prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length) + return torch.tensor(prototype_filter, dtype=torch.float32) + +def pad_to_nearest_power_of_two(x): + """ + Pads the input tensor 'x' on both sides such that its last dimension + becomes the nearest larger power of two. + + Parameters: + ----------- + x : torch.Tensor + The input tensor to be padded. + + Returns: + -------- + torch.Tensor + The padded tensor. + """ + current_length = x.shape[-1] + target_length = 2**math.ceil(math.log2(current_length)) + + total_padding = target_length - current_length + left_padding = total_padding // 2 + right_padding = total_padding - left_padding + + return nn.functional.pad(x, (left_padding, right_padding)) + +def apply_alias_cancellation(x): + """ + Applies alias cancellation by inverting the sign of every + second element of every second row, starting from the second + row's first element in a tensor. + + This operation helps ensure that the aliasing introduced in + each band during the decomposition will be counteracted during + the reconstruction. + + Parameters: + ----------- + x : torch.Tensor + The input tensor. + + Returns: + -------- + torch.Tensor + Tensor with specific elements' sign inverted for alias cancellation. + """ + + # Create a mask of the same shape as 'x', initialized with all ones + mask = torch.ones_like(x) + + # Update specific elements in the mask to -1 to perform inversion + mask[..., 1::2, ::2] = -1 + + # Apply the mask to the input tensor 'x' + return x * mask + +def ensure_odd_length(tensor): + """ + Pads the last dimension of a tensor to ensure its size is odd. + + Parameters: + ----------- + tensor : torch.Tensor + Input tensor whose last dimension might need padding. + + Returns: + -------- + torch.Tensor + The original tensor if its last dimension was already odd, + or the padded tensor with an odd-sized last dimension. + """ + + last_dim_size = tensor.shape[-1] + + if last_dim_size % 2 == 0: + tensor = nn.functional.pad(tensor, (0, 1)) + + return tensor + +def polyphase_analysis(signal, filter_bank): + """ + Applies the polyphase method to efficiently analyze the signal using a filter bank. + + Parameters: + ----------- + signal : torch.Tensor + Input signal tensor with shape (Batch x Channels x Length). + + filter_bank : torch.Tensor + Filter bank tensor with shape (Bands x Length). + + Returns: + -------- + torch.Tensor + Signal split into sub-bands. (Batch x Channels x Bands x Length) + """ + + num_bands = filter_bank.shape[0] + num_channels = signal.shape[1] + + # Rearrange signal for polyphase processing. + # Also combine Batch x Channel into one dimension for now. + #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) + signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands) + + # Rearrange the filter bank for matching signal shape + filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands) + + # Apply convolution with appropriate padding to maintain spatial dimensions + padding = filter_bank.shape[-1] // 2 + filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding) + + # Truncate the last dimension post-convolution to adjust the output shape + filtered_signal = filtered_signal[..., :-1] + # Rearrange the first dimension back into Batch x Channels + filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels) + + return filtered_signal + +def polyphase_synthesis(signal, filter_bank): + """ + Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. + + Parameters + ---------- + signal : torch.Tensor + Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length). + + filter_bank : torch.Tensor + Analysis filter bank (shape: Bands x Length). + + should_rearrange : bool, optional + Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True. + + Returns + ------- + torch.Tensor + Reconstructed signal (shape: Batch x Channels X Length) + """ + + num_bands = filter_bank.shape[0] + num_channels = signal.shape[1] + + # Rearrange the filter bank + filter_bank = filter_bank.flip(-1) + filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands) + + # Combine Batch x Channels into one dimension for now. + signal = rearrange(signal, "b c n t -> (b c) n t") + + # Apply convolution with appropriate padding + padding_amount = filter_bank.shape[-1] // 2 + 1 + reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount)) + + # Scale the result + reconstructed_signal = reconstructed_signal[..., :-1] * num_bands + + # Reorganize the output and truncate + reconstructed_signal = reconstructed_signal.flip(1) + reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands) + reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:] + + return reconstructed_signal \ No newline at end of file diff --git a/prismaudio_core/models/pretransforms.py b/prismaudio_core/models/pretransforms.py new file mode 100644 index 0000000..c9942db --- /dev/null +++ b/prismaudio_core/models/pretransforms.py @@ -0,0 +1,258 @@ +import torch +from einops import rearrange +from torch import nn + +class Pretransform(nn.Module): + def __init__(self, enable_grad, io_channels, is_discrete): + super().__init__() + + self.is_discrete = is_discrete + self.io_channels = io_channels + self.encoded_channels = None + self.downsampling_ratio = None + + self.enable_grad = enable_grad + + def encode(self, x): + raise NotImplementedError + + def decode(self, z): + raise NotImplementedError + + def tokenize(self, x): + raise NotImplementedError + + def decode_tokens(self, tokens): + raise NotImplementedError + +class AutoencoderPretransform(Pretransform): + def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): + super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) + self.model = model + self.model.requires_grad_(False).eval() + self.scale=scale + self.downsampling_ratio = model.downsampling_ratio + self.io_channels = model.io_channels + self.sample_rate = model.sample_rate + + self.model_half = model_half + self.iterate_batch = iterate_batch + + self.encoded_channels = model.latent_dim + + self.chunked = chunked + self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None + self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None + + if self.model_half: + self.model.half() + + def encode(self, x, **kwargs): + + if self.model_half: + x = x.half() + self.model.to(torch.float16) + + encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + encoded = encoded.float() + + return encoded / self.scale + + def decode(self, z, **kwargs): + z = z * self.scale + + if self.model_half: + z = z.half() + self.model.to(torch.float16) + + decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + decoded = decoded.float() + + return decoded + + def tokenize(self, x, **kwargs): + assert self.model.is_discrete, "Cannot tokenize with a continuous model" + + _, info = self.model.encode(x, return_info = True, **kwargs) + + return info[self.model.bottleneck.tokens_id] + + def decode_tokens(self, tokens, **kwargs): + assert self.model.is_discrete, "Cannot decode tokens with a continuous model" + + return self.model.decode_tokens(tokens, **kwargs) + + def load_state_dict(self, state_dict, strict=True): + self.model.load_state_dict(state_dict, strict=strict) + +class WaveletPretransform(Pretransform): + def __init__(self, channels, levels, wavelet): + super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) + + from .wavelets import WaveletEncode1d, WaveletDecode1d + + self.encoder = WaveletEncode1d(channels, levels, wavelet) + self.decoder = WaveletDecode1d(channels, levels, wavelet) + + self.downsampling_ratio = 2 ** levels + self.io_channels = channels + self.encoded_channels = channels * self.downsampling_ratio + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + return self.decoder(z) + +class PQMFPretransform(Pretransform): + def __init__(self, attenuation=100, num_bands=16): + # TODO: Fix PQMF to take in in-channels + super().__init__(enable_grad=False, io_channels=1, is_discrete=False) + from .pqmf import PQMF + self.pqmf = PQMF(attenuation, num_bands) + + + def encode(self, x): + # x is (Batch x Channels x Time) + x = self.pqmf.forward(x) + # pqmf.forward returns (Batch x Channels x Bands x Time) + # but Pretransform needs Batch x Channels x Time + # so concatenate channels and bands into one axis + return rearrange(x, "b c n t -> b (c n) t") + + def decode(self, x): + # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) + x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) + # returns (Batch x Channels x Time) + return self.pqmf.inverse(x) + +class PretrainedDACPretransform(Pretransform): + def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + import dac + + model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) + + self.model = dac.DAC.load(model_path) + + self.quantize_on_decode = quantize_on_decode + + if model_type == "44khz": + self.downsampling_ratio = 512 + else: + self.downsampling_ratio = 320 + + self.io_channels = 1 + + self.scale = scale + + self.chunked = chunked + + self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.n_codebooks + + self.codebook_size = self.model.codebook_size + + def encode(self, x): + + latents = self.model.encoder(x) + + if self.quantize_on_decode: + output = latents + else: + z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + output = z + + if self.scale != 1.0: + output = output / self.scale + + return output + + def decode(self, z): + + if self.scale != 1.0: + z = z * self.scale + + if self.quantize_on_decode: + z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + return self.model.decode(z) + + def tokenize(self, x): + return self.model.encode(x)[1] + + def decode_tokens(self, tokens): + latents = self.model.quantizer.from_codes(tokens) + return self.model.decode(latents) + +class AudiocraftCompressionPretransform(Pretransform): + def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + try: + from audiocraft.models import CompressionModel + except ImportError: + raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.") + + self.model = CompressionModel.get_pretrained(model_type) + + self.quantize_on_decode = quantize_on_decode + + self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) + + self.sample_rate = self.model.sample_rate + + self.io_channels = self.model.channels + + self.scale = scale + + #self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.num_codebooks + + self.codebook_size = self.model.cardinality + + self.model.to(torch.float16).eval().requires_grad_(False) + + def encode(self, x): + + assert False, "Audiocraft compression models do not support continuous encoding" + + # latents = self.model.encoder(x) + + # if self.quantize_on_decode: + # output = latents + # else: + # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + # output = z + + # if self.scale != 1.0: + # output = output / self.scale + + # return output + + def decode(self, z): + + assert False, "Audiocraft compression models do not support continuous decoding" + + # if self.scale != 1.0: + # z = z * self.scale + + # if self.quantize_on_decode: + # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + # return self.model.decode(z) + + def tokenize(self, x): + with torch.cuda.amp.autocast(enabled=False): + return self.model.encode(x.to(torch.float16))[0] + + def decode_tokens(self, tokens): + with torch.cuda.amp.autocast(enabled=False): + return self.model.decode(tokens) diff --git a/prismaudio_core/models/transformer.py b/prismaudio_core/models/transformer.py new file mode 100644 index 0000000..495a34f --- /dev/null +++ b/prismaudio_core/models/transformer.py @@ -0,0 +1,993 @@ +from functools import reduce, partial +from packaging import version + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +import torch +import torch.nn.functional as F +from torch import nn, einsum +from torch.cuda.amp import autocast +from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP +from typing import Callable, Literal + +try: + from flash_attn import flash_attn_func, flash_attn_kvpacked_func + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False + flash_attn_kvpacked_func = None + flash_attn_func = None + +from .utils import compile +try: + import natten +except ImportError: + natten = None + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return x * (1 + scale) + shift + +# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License +# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt + +def create_causal_mask(i, j, device): + return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + +def or_reduce(masks): + head, *body = masks + for rest in body: + head = head | rest + return head + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.max_seq_len = max_seq_len + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = (pos - seq_start_pos[..., None]).clamp(min = 0) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return pos_emb + +class ScaledSinusoidalEmbedding(nn.Module): + def __init__(self, dim, theta = 10000): + super().__init__() + assert (dim % 2) == 0, 'dimension must be divisible by 2' + self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + + half_dim = dim // 2 + freq_seq = torch.arange(half_dim).float() / half_dim + inv_freq = theta ** -freq_seq + self.register_buffer('inv_freq', inv_freq, persistent = False) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = pos - seq_start_pos[..., None] + + emb = einsum('i, j -> i j', pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim = -1) + return emb * self.scale + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + use_xpos = False, + scale_base = 512, + interpolation_factor = 1., + base = 10000, + base_rescale_factor = 1. + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + base *= base_rescale_factor ** (dim / (dim - 2)) + + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + assert interpolation_factor >= 1. + self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.register_buffer('scale', scale) + + def forward_from_seq_len(self, seq_len): + device = self.inv_freq.device + + t = torch.arange(seq_len, device = device) + return self.forward(t) + + @autocast(enabled = False) + def forward(self, t): + device = self.inv_freq.device + + t = t.to(torch.float32) + + t = t / self.interpolation_factor + + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim = -1) + + if self.scale is None: + return freqs, 1. + + power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + return freqs, scale + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... j d', j = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +@autocast(enabled = False) +def apply_rotary_pos_emb(t, freqs, scale = 1): + out_dtype = t.dtype + + # cast to float32 if necessary for numerical stability + dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs, t = freqs.to(dtype), t.to(dtype) + freqs = freqs[-seq_len:, :] + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + + t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) + + return torch.cat((t, t_unrotated), dim = -1) + +# norms +class DynamicTanh(nn.Module): + def __init__(self, dim, init_alpha=10.0): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * init_alpha) + self.gamma = nn.Parameter(torch.ones(dim)) + self.beta = nn.Parameter(torch.zeros(dim)) + + def forward(self, x): + x = F.tanh(self.alpha * x) + return self.gamma * x + self.beta + +class RunningInstanceNorm(nn.Module): + def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True): + super().__init__() + self.register_buffer("running_mean", torch.zeros(1,1,dim)) + self.register_buffer("running_std", torch.ones(1,1,dim)) + self.saturate = saturate + self.eps = eps + self.momentum = momentum + self.dim = dim + self.trainable_gain = trainable_gain + if self.trainable_gain: + self.gain = nn.Parameter(torch.ones(1)) + + def _update_stats(self, x): + self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum) + self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps) + + def forward(self, x): + if self.training: + self._update_stats(x) + x = (x - self.running_mean) / self.running_std + if self.saturate: + x = torch.asinh(x) + if self.trainable_gain: + x = x * self.gain + return x + +class LayerNorm(nn.Module): + def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5): + """ + bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less + """ + super().__init__() + + if fix_scale: + self.register_buffer("gamma", torch.ones(dim)) + else: + self.gamma = nn.Parameter(torch.ones(dim)) + + if bias: + self.beta = nn.Parameter(torch.zeros(dim)) + else: + self.register_buffer("beta", torch.zeros(dim)) + + self.eps = eps + + self.force_fp32 = force_fp32 + + def forward(self, x): + if not self.force_fp32: + return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps) + else: + output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps) + return output.to(x.dtype) + +class LayerScale(nn.Module): + def __init__(self, dim, init_val = 1e-5): + super().__init__() + self.scale = nn.Parameter(torch.full([dim], init_val)) + def forward(self, x): + return x * self.scale + +class GLU(nn.Module): + def __init__( + self, + dim_in, + dim_out, + activation: Callable, + use_conv = False, + conv_kernel_size = 3, + ): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2)) + self.use_conv = use_conv + + def forward(self, x): + if self.use_conv: + x = rearrange(x, 'b n d -> b d n') + x = self.proj(x) + x = rearrange(x, 'b d n -> b n d') + else: + x = self.proj(x) + + x, gate = x.chunk(2, dim = -1) + return x * self.act(gate) + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out = None, + mult = 4, + no_bias = False, + glu = True, + use_conv = False, + conv_kernel_size = 3, + zero_init_output = True, + ): + super().__init__() + inner_dim = int(dim * mult) + + # Default to SwiGLU + + activation = nn.SiLU() + + dim_out = dim if dim_out is None else dim_out + + if glu: + linear_in = GLU(dim, inner_dim, activation) + else: + linear_in = nn.Sequential( + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias), + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + activation + ) + + linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias) + + # init last linear layer to 0 + if zero_init_output: + nn.init.zeros_(linear_out.weight) + if not no_bias: + nn.init.zeros_(linear_out.bias) + + + self.ff = nn.Sequential( + linear_in, + Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + linear_out, + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + ) + + def forward(self, x): + return self.ff(x) + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + dim_context = None, + causal = False, + zero_init_output=True, + qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none', + differential = False, + feat_scale = False + ): + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.differential = differential + + dim_kv = dim_context if dim_context is not None else dim + + self.num_heads = dim // dim_heads + self.kv_heads = dim_kv // dim_heads + + if dim_context is not None: + if differential: + self.to_q = nn.Linear(dim, dim * 2, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False) + else: + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) + else: + if differential: + self.to_qkv = nn.Linear(dim, dim * 5, bias=False) + else: + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.to_out = nn.Linear(dim, dim, bias=False) + + if zero_init_output: + nn.init.zeros_(self.to_out.weight) + + if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']: + raise ValueError(f'qk_norm must be one of ["l2", "ln", "none"], got {qk_norm}') + + self.qk_norm = qk_norm + if self.qk_norm == "ln": + self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + elif self.qk_norm == 'rns': + self.q_norm = nn.RMSNorm(dim_heads) + self.k_norm = nn.RMSNorm(dim_heads) + elif self.qk_norm == 'dyt': + self.q_norm = DynamicTanh(dim_heads) + self.k_norm = DynamicTanh(dim_heads) + + self.sdp_kwargs = dict( + enable_flash = True, + enable_math = True, + enable_mem_efficient = True + ) + + self.feat_scale = feat_scale + + if self.feat_scale: + self.lambda_dc = nn.Parameter(torch.zeros(dim)) + self.lambda_hf = nn.Parameter(torch.zeros(dim)) + + self.causal = causal + if causal: + print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.') + + @compile + def apply_qk_layernorm(self, q, k): + q_type = q.dtype + k_type = k.dtype + q = self.q_norm(q).to(q_type) + k = self.k_norm(k).to(k_type) + return q, k + + + def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None): + + if self.num_heads != self.kv_heads: + # Repeat interleave kv_heads to match q_heads for grouped query attention + heads_per_kv_head = self.num_heads // self.kv_heads + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + flash_attn_available = HAS_FLASH_ATTN + if flash_attn_sliding_window is not None and (not flash_attn_available): + print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available") + + if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None: + print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention") + + if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None): + print(f"Disabling FlexAttention because causal is set") + flex_attention_block_mask = None + flex_attention_score_mod = None + + if flex_attention_block_mask is not None or flex_attention_score_mod is not None: + out = flex_attention_compiled(q,k,v, + block_mask = flex_attention_block_mask, + score_mod = flex_attention_score_mod) + elif flash_attn_available: + fa_dtype_in = q.dtype + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v)) + + if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16: + q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v)) + + out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1]) + + out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') + else: + out = F.scaled_dot_product_attention(q, k, v, is_causal = causal) + return out + + + #@compile + def forward( + self, + x, + context = None, + rotary_pos_emb = None, + causal = None, + flex_attention_block_mask = None, + flex_attention_score_mod = None, + flash_attn_sliding_window = None + ): + h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None + + kv_input = context if has_context else x + + if hasattr(self, 'to_q'): + # Use separate linear projections for q and k/v + if self.differential: + q, q_diff = self.to_q(x).chunk(2, dim=-1) + q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff)) + q = torch.stack([q, q_diff], dim = 1) + k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1) + k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v)) + k = torch.stack([k, k_diff], dim = 1) + else: + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + else: + # Use fused linear projection + if self.differential: + q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1) + q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff)) + q = torch.stack([q, q_diff], dim = 1) + k = torch.stack([k, k_diff], dim = 1) + else: + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + # Normalize q and k for cosine sim attention + if self.qk_norm == "l2": + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + elif self.qk_norm != "none": + q, k = self.apply_qk_layernorm(q, k) + + if rotary_pos_emb is not None: + freqs, _ = rotary_pos_emb + q_dtype = q.dtype + k_dtype = k.dtype + q = q.to(torch.float32) + k = k.to(torch.float32) + freqs = freqs.to(torch.float32) + if q.shape[-2] >= k.shape[-2]: + ratio = q.shape[-2] / k.shape[-2] + q_freqs, k_freqs = freqs, ratio * freqs + else: + ratio = k.shape[-2] / q.shape[-2] + q_freqs, k_freqs = ratio * freqs, freqs + q = apply_rotary_pos_emb(q, q_freqs) + k = apply_rotary_pos_emb(k, k_freqs) + q = q.to(v.dtype) + k = k.to(v.dtype) + + n, device = q.shape[-2], q.device + + causal = self.causal if causal is None else causal + + if n == 1 and causal: + causal = False + + if self.differential: + q, q_diff = q.unbind(dim = 1) + k, k_diff = k.unbind(dim = 1) + out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + out = out - out_diff + else: + out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + + # merge heads + out = rearrange(out, ' b h n d -> b n (h d)') + + # Communicate between heads + + # with autocast(enabled = False): + # out_dtype = out.dtype + # out = out.to(torch.float32) + # out = self.to_out(out).to(out_dtype) + out = self.to_out(out) + + if self.feat_scale: + out_dc = out.mean(dim=-2, keepdim=True) + out_hf = out - out_dc + + # Selectively modulate DC and high frequency components + out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf + + return out + +class ConformerModule(nn.Module): + def __init__( + self, + dim, + norm_kwargs = {}, + ): + + super().__init__() + + self.dim = dim + + self.in_norm = LayerNorm(dim, **norm_kwargs) + self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + self.glu = GLU(dim, dim, nn.SiLU()) + self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) + self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + + #@compile + def forward(self, x): + x = self.in_norm(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.glu(x) + x = rearrange(x, 'b n d -> b d n') + x = self.depthwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.mid_norm(x) + x = self.swish(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv_2(x) + x = rearrange(x, 'b d n -> b n d') + + return x + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + cross_attend = False, + dim_context = None, + global_cond_dim = None, + causal = False, + zero_init_branch_outputs = True, + conformer = False, + layer_ix = -1, + remove_norms = False, + add_rope = False, + layer_scale = False, + use_sync_block_film = False, + attn_kwargs = {}, + ff_kwargs = {}, + norm_kwargs = {} + ): + + super().__init__() + self.dim = dim + self.dim_heads = min(dim_heads,dim) + self.cross_attend = cross_attend + self.dim_context = dim_context + self.causal = causal + if layer_scale and zero_init_branch_outputs: + print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False') + zero_init_branch_outputs = False + + self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim) + + self.add_rope = add_rope + + self.self_attn = Attention( + dim, + dim_heads = self.dim_heads, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.cross_attend = cross_attend + if cross_attend: + self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim) + self.cross_attn = Attention( + dim, + dim_heads = self.dim_heads, + dim_context=dim_context, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim) + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) + self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.layer_ix = layer_ix + + self.conformer = None + if conformer: + self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) + self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.global_cond_dim = global_cond_dim + if global_cond_dim is not None: + self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5) + + self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None + + if use_sync_block_film: + self.sync_film_generator = nn.Sequential( + nn.Linear(dim, dim, bias=False), + nn.SiLU(), + nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift + ) + + @compile + def forward( + self, + x, + context = None, + global_cond=None, + rotary_pos_emb = None, + self_attention_block_mask = None, + self_attention_score_mod = None, + cross_attention_block_mask = None, + cross_attention_score_mod = None, + self_attention_flash_sliding_window = None, + cross_attention_flash_sliding_window = None, + sync_cond = None, + prepend_length=0 + ): + if rotary_pos_emb is None and self.add_rope: + rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2]) + + if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + if len(global_cond.shape) == 2: + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1) + else: + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1) + + # self-attention with adaLN + residual = x + x = self.pre_norm(x) + x = x * (1 + scale_self) + shift_self + x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window) + x = x * torch.sigmoid(1 - gate_self) + x = self.self_attn_scale(x) + x = x + residual + + if context is not None and self.cross_attend: + x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window)) + + if self.conformer is not None: + x = x + self.conformer_scale(self.conformer(x)) + + if sync_cond is not None and hasattr(self, 'sync_film_generator'): + scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) + x = x * (1 + scale) + shift + + # feedforward with adaLN + residual = x + x = self.ff_norm(x) + x = x * (1 + scale_ff) + shift_ff + x = self.ff(x) + x = x * torch.sigmoid(1 - gate_ff) + x = self.ff_scale(x) + x = x + residual + + else: + x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)) + + if context is not None and self.cross_attend: + x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window)) + + if self.conformer is not None: + x = x + self.conformer_scale(self.conformer(x)) + + if sync_cond is not None and hasattr(self, 'sync_film_generator'): + prepend_part = x[:, :prepend_length, :] + audio_part = x[:, prepend_length:, :] + scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) + modulated_audio_part = audio_part * (1 + scale) + shift + x = torch.cat([prepend_part, modulated_audio_part], dim=1) + + x = x + self.ff_scale(self.ff(self.ff_norm(x))) + return x + +class ContinuousTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in = None, + dim_out = None, + dim_heads = 64, + cross_attend=False, + cond_token_dim=None, + pre_cross_attn_ix=-1, + final_cross_attn_ix=-1, + global_cond_dim=None, + causal=False, + rotary_pos_emb=True, + zero_init_branch_outputs=True, + conformer=False, + use_sinusoidal_emb=False, + use_abs_pos_emb=False, + abs_pos_emb_max_length=10000, + num_memory_tokens=0, + sliding_window=None, + use_mlp=False, + use_add_norm=False, + use_gated=False, + use_final_layer=False, + use_zeros=False, + use_conv=False, + use_fusion_mlp=False, + use_film=False, + use_sync_film=False, + use_sync_gated=False, + **kwargs + ): + + super().__init__() + + self.dim = dim + self.depth = depth + self.causal = causal + self.layers = nn.ModuleList([]) + if use_mlp: + self.project_in = nn.Sequential( + nn.Linear(dim_in, dim, bias=False), + nn.SiLU(), + nn.Linear(dim, dim, bias=False) + ) + else: + self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() + self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() + self.video_temporal_conv = None + self.audio_temporal_conv = None + self.fusion_mlp = None + if use_conv: + self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1) + self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1) + if use_fusion_mlp: + self.fusion_mlp = nn.Sequential( + nn.Linear(dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + + if rotary_pos_emb: + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + else: + self.rotary_pos_emb = None + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + self.use_sinusoidal_emb = use_sinusoidal_emb + if use_sinusoidal_emb: + self.pos_emb = ScaledSinusoidalEmbedding(dim) + + self.use_abs_pos_emb = use_abs_pos_emb + if use_abs_pos_emb: + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens) + + self.adaLN_modulation = None + if global_cond_dim is not None: + if use_final_layer: + self.norm_final = LayerNorm(dim) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + dim, 2 * dim, bias=True + ), + ) + + if use_zeros: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.project_out.weight, 0) + self.global_cond_embedder = nn.Sequential( + nn.Linear(global_cond_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim * 6) + ) + if use_zeros: + nn.init.constant_(self.global_cond_embedder[-1].weight, 0) + nn.init.constant_(self.global_cond_embedder[-1].bias, 0) + nn.init.constant_(self.global_cond_embedder[0].weight, 0) + nn.init.constant_(self.global_cond_embedder[0].bias, 0) + + self.final_cross_attn_ix = final_cross_attn_ix + self.use_gated = use_gated + self.use_film = use_film + self.use_add_norm = use_add_norm + if self.use_add_norm: + self.add_norm = nn.LayerNorm(dim) + if use_gated: + self.gate = nn.Parameter(torch.ones(1, 1, dim)) + + if use_film: + self.film_generator = nn.Sequential( + nn.Linear(dim, dim, bias=False), + nn.SiLU(), + nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift + ) + else: + self.film_generator = None + + if use_sync_film: + self.sync_film_generator = nn.Sequential( + nn.Linear(dim, dim, bias=False), + nn.SiLU(), + nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift + ) + else: + self.sync_film_generator = None + if use_sync_gated: + self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim)) + else: + self.sync_gate = None + + self.sliding_window = sliding_window + + for i in range(depth): + should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix)) + # print(f"Layer {i} cross attends: {should_cross_attend}") + self.layers.append( + TransformerBlock( + dim, + dim_heads = dim_heads, + cross_attend = should_cross_attend, + dim_context = cond_token_dim, + global_cond_dim = global_cond_dim, + causal = causal, + zero_init_branch_outputs = zero_init_branch_outputs, + conformer=conformer, + layer_ix=i, + **kwargs + ) + ) + + def forward( + self, + x, + mask = None, + prepend_embeds = None, + prepend_mask = None, + add_cond = None, + sync_cond = None, + global_cond = None, + return_info = False, + use_checkpointing = True, + exit_layer_ix = None, + video_dropout_prob = 0.0, + **kwargs + ): + batch, seq, device = *x.shape[:2], x.device + model_dtype = next(self.parameters()).dtype + x = x.to(model_dtype) + + info = { + "hidden_states": [], + } + + x = self.project_in(x) + if add_cond is not None: + if self.use_gated: + gate = torch.sigmoid(self.gate) + x = x + gate * add_cond + elif self.use_film: + scale, shift = self.film_generator(add_cond).chunk(2, dim=-1) + x = x * (1 + scale) + shift + else: + x = x + add_cond + + if self.use_add_norm: + x = self.add_norm(x) + if self.fusion_mlp is not None: + x = self.fusion_mlp(x) + + if sync_cond is not None: + if self.sync_film_generator is not None: + scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) + x = x * (1 + scale) + shift + elif self.sync_gate is not None: + gate_value = torch.sigmoid(self.sync_gate) + x = x + gate_value * sync_cond + # else: + # x = x + sync_cond + + if prepend_embeds is not None: + prepend_length, prepend_dim = prepend_embeds.shape[1:] + + assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + + x = torch.cat((prepend_embeds, x), dim = -2) + + if self.num_memory_tokens > 0: + memory_tokens = self.memory_tokens.expand(batch, -1, -1) + x = torch.cat((memory_tokens, x), dim=1) + + if self.rotary_pos_emb is not None: + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + else: + rotary_pos_emb = None + + if self.use_sinusoidal_emb or self.use_abs_pos_emb: + x = x + self.pos_emb(x) + + if global_cond is not None and self.global_cond_embedder is not None: + global_cond_embed = self.global_cond_embedder(global_cond) + else: + global_cond_embed = global_cond + # Iterate over the transformer layers + for layer_ix, layer in enumerate(self.layers): + if use_checkpointing: + x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs) + else: + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs) + + if return_info: + info["hidden_states"].append(x) + + if exit_layer_ix is not None and layer_ix == exit_layer_ix: + x = x[:, self.num_memory_tokens:, :] + + if return_info: + return x, info + + return x + + x = x[:, self.num_memory_tokens:, :] + if global_cond is not None and self.adaLN_modulation is not None: + if len(global_cond.shape) == 2: + global_cond = global_cond.unsqueeze(1) + shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.project_out(x) + + if return_info: + return x, info + + return x diff --git a/prismaudio_core/models/utils.py b/prismaudio_core/models/utils.py new file mode 100644 index 0000000..a90f1e2 --- /dev/null +++ b/prismaudio_core/models/utils.py @@ -0,0 +1,177 @@ +import torch +from safetensors.torch import load_file +from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor +#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline +from torch.nn.utils import remove_weight_norm + +def load_ckpt_state_dict(ckpt_path, prefix=None): + if ckpt_path.endswith(".safetensors"): + state_dict = load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + # 过滤特定前缀的state_dict + filtered_state_dict = {k.replace(f'{prefix}',''): v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict + + return filtered_state_dict + +def remove_weight_norm_from_model(model): + for module in model.modules(): + if hasattr(module, "weight"): + print(f"Removing weight norm from {module}") + remove_weight_norm(module) + + return model + +# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): + """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. + + Args: + input (torch.Tensor): The input tensor containing probabilities. + num_samples (int): Number of samples to draw. + replacement (bool): Whether to draw with replacement or not. + Keywords args: + generator (torch.Generator): A pseudorandom number generator for sampling. + Returns: + torch.Tensor: Last dimension contains num_samples indices + sampled from the multinomial probability distribution + located in the last dimension of tensor input. + """ + + if num_samples == 1: + q = torch.empty_like(input).exponential_(1, generator=generator) + return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64) + + input_ = input.reshape(-1, input.shape[-1]) + output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) + output = output_.reshape(*list(input.shape[:-1]), -1) + return output + + +def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: + """Sample next token from top K values along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + k (int): The k in “top-k”. + Returns: + torch.Tensor: Sampled tokens. + """ + top_k_value, _ = torch.topk(probs, k, dim=-1) + min_value_top_k = top_k_value[..., [-1]] + probs *= (probs >= min_value_top_k).float() + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs, num_samples=1) + return next_token + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """Sample next token from top P probabilities along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + p (int): The p in “top-p”. + Returns: + torch.Tensor: Sampled tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort *= (~mask).float() + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + +def next_power_of_two(n): + return 2 ** (n - 1).bit_length() + +def next_multiple_of_64(n): + return ((n + 63) // 64) * 64 + + +# mask construction helpers + +def mask_from_start_end_indices( + seq_len: int, + start: Tensor, + end: Tensor +): + assert start.shape == end.shape + device = start.device + + seq = torch.arange(seq_len, device = device, dtype = torch.long) + seq = seq.reshape(*((-1,) * start.ndim), seq_len) + seq = seq.expand(*start.shape, seq_len) + + mask = seq >= start[..., None].long() + mask &= seq < end[..., None].long() + return mask + +def mask_from_frac_lengths( + seq_len: int, + frac_lengths: Tensor +): + device = frac_lengths.device + + lengths = (frac_lengths * seq_len).long() + max_start = seq_len - lengths + + rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1) + start = (max_start * rand).clamp(min = 0) + end = start + lengths + + return mask_from_start_end_indices(seq_len, start, end) + +def _build_spline(video_feat, video_t, target_t): + # 三次样条插值核心实现 + coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0,2,1)) + spline = NaturalCubicSpline(coeffs) + return spline.evaluate(target_t).permute(0,2,1) + +def resample(video_feat, audio_latent): + """ + 9s + video_feat: [B, 72, D] + audio_latent: [B, D', 194] or int + """ + B, Tv, D = video_feat.shape + + if isinstance(audio_latent, torch.Tensor): + # audio_latent is a tensor + if audio_latent.shape[1] != 64: + Ta = audio_latent.shape[1] + else: + Ta = audio_latent.shape[2] + elif isinstance(audio_latent, int): + # audio_latent is an int + Ta = audio_latent + else: + raise TypeError("audio_latent must be either a tensor or an int") + + # 构建时间戳 (关键改进点) + video_time = torch.linspace(0, 9, Tv, device=video_feat.device) + audio_time = torch.linspace(0, 9, Ta, device=video_feat.device) + + # 三维化处理 (Batch, Feature, Time) + video_feat = video_feat.permute(0, 2, 1) # [B, D, Tv] + + # 三次样条插值 + aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta] + return aligned_video.permute(0, 2, 1) # [B, Ta, D] + +import os +enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1" + +def compile(function, *args, **kwargs): + + if enable_torch_compile: + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + return function \ No newline at end of file