6e1186d5bd
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
995 lines
37 KiB
Python
995 lines
37 KiB
Python
from functools import reduce, partial
|
|
from packaging import version
|
|
|
|
from einops import rearrange, repeat
|
|
from einops.layers.torch import Rearrange
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn, einsum
|
|
from torch.cuda.amp import autocast
|
|
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
|
from typing import Callable, Literal
|
|
|
|
try:
|
|
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
|
|
HAS_FLASH_ATTN = True
|
|
except ImportError:
|
|
HAS_FLASH_ATTN = False
|
|
flash_attn_kvpacked_func = None
|
|
flash_attn_func = None
|
|
|
|
from .utils import compile
|
|
try:
|
|
import natten
|
|
except ImportError:
|
|
natten = None
|
|
|
|
def checkpoint(function, *args, **kwargs):
|
|
kwargs.setdefault("use_reentrant", False)
|
|
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
|
|
|
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
|
return x * (1 + scale) + shift
|
|
|
|
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
|
|
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
|
|
|
|
def create_causal_mask(i, j, device):
|
|
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
|
|
|
|
def or_reduce(masks):
|
|
head, *body = masks
|
|
for rest in body:
|
|
head = head | rest
|
|
return head
|
|
|
|
# positional embeddings
|
|
|
|
class AbsolutePositionalEmbedding(nn.Module):
|
|
def __init__(self, dim, max_seq_len):
|
|
super().__init__()
|
|
self.scale = dim ** -0.5
|
|
self.max_seq_len = max_seq_len
|
|
self.emb = nn.Embedding(max_seq_len, dim)
|
|
|
|
def forward(self, x, pos = None, seq_start_pos = None):
|
|
seq_len, device = x.shape[1], x.device
|
|
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
|
|
|
|
if pos is None:
|
|
pos = torch.arange(seq_len, device = device)
|
|
|
|
if seq_start_pos is not None:
|
|
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
|
|
|
|
pos_emb = self.emb(pos)
|
|
pos_emb = pos_emb * self.scale
|
|
return pos_emb
|
|
|
|
class ScaledSinusoidalEmbedding(nn.Module):
|
|
def __init__(self, dim, theta = 10000):
|
|
super().__init__()
|
|
assert (dim % 2) == 0, 'dimension must be divisible by 2'
|
|
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
|
|
|
|
half_dim = dim // 2
|
|
freq_seq = torch.arange(half_dim).float() / half_dim
|
|
inv_freq = theta ** -freq_seq
|
|
self.register_buffer('inv_freq', inv_freq, persistent = False)
|
|
|
|
def forward(self, x, pos = None, seq_start_pos = None):
|
|
seq_len, device = x.shape[1], x.device
|
|
|
|
if pos is None:
|
|
pos = torch.arange(seq_len, device = device)
|
|
|
|
if seq_start_pos is not None:
|
|
pos = pos - seq_start_pos[..., None]
|
|
|
|
emb = einsum('i, j -> i j', pos, self.inv_freq)
|
|
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
|
|
return emb * self.scale
|
|
|
|
class RotaryEmbedding(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
use_xpos = False,
|
|
scale_base = 512,
|
|
interpolation_factor = 1.,
|
|
base = 10000,
|
|
base_rescale_factor = 1.
|
|
):
|
|
super().__init__()
|
|
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
|
# has some connection to NTK literature
|
|
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
|
base *= base_rescale_factor ** (dim / (dim - 2))
|
|
|
|
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
|
self.register_buffer('inv_freq', inv_freq)
|
|
|
|
assert interpolation_factor >= 1.
|
|
self.interpolation_factor = interpolation_factor
|
|
|
|
if not use_xpos:
|
|
self.register_buffer('scale', None)
|
|
return
|
|
|
|
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
|
|
|
self.scale_base = scale_base
|
|
self.register_buffer('scale', scale)
|
|
|
|
def forward_from_seq_len(self, seq_len):
|
|
device = self.inv_freq.device
|
|
|
|
t = torch.arange(seq_len, device = device)
|
|
return self.forward(t)
|
|
|
|
@autocast(enabled = False)
|
|
def forward(self, t):
|
|
device = self.inv_freq.device
|
|
|
|
t = t.to(torch.float32)
|
|
|
|
t = t / self.interpolation_factor
|
|
|
|
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
|
|
freqs = torch.cat((freqs, freqs), dim = -1)
|
|
|
|
if self.scale is None:
|
|
return freqs, 1.
|
|
|
|
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
|
scale = self.scale ** rearrange(power, 'n -> n 1')
|
|
scale = torch.cat((scale, scale), dim = -1)
|
|
|
|
return freqs, scale
|
|
|
|
def rotate_half(x):
|
|
x = rearrange(x, '... (j d) -> ... j d', j = 2)
|
|
x1, x2 = x.unbind(dim = -2)
|
|
return torch.cat((-x2, x1), dim = -1)
|
|
|
|
@autocast(enabled = False)
|
|
def apply_rotary_pos_emb(t, freqs, scale = 1):
|
|
out_dtype = t.dtype
|
|
|
|
# cast to float32 if necessary for numerical stability
|
|
dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
|
|
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
|
|
freqs, t = freqs.to(dtype), t.to(dtype)
|
|
freqs = freqs[-seq_len:, :]
|
|
|
|
if t.ndim == 4 and freqs.ndim == 3:
|
|
freqs = rearrange(freqs, 'b n d -> b 1 n d')
|
|
|
|
# partial rotary embeddings, Wang et al. GPT-J
|
|
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
|
|
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
|
|
|
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
|
|
|
|
return torch.cat((t, t_unrotated), dim = -1)
|
|
|
|
# norms
|
|
class DynamicTanh(nn.Module):
|
|
def __init__(self, dim, init_alpha=10.0):
|
|
super().__init__()
|
|
self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
|
|
self.gamma = nn.Parameter(torch.ones(dim))
|
|
self.beta = nn.Parameter(torch.zeros(dim))
|
|
|
|
def forward(self, x):
|
|
x = F.tanh(self.alpha * x)
|
|
return self.gamma * x + self.beta
|
|
|
|
class RunningInstanceNorm(nn.Module):
|
|
def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True):
|
|
super().__init__()
|
|
self.register_buffer("running_mean", torch.zeros(1,1,dim))
|
|
self.register_buffer("running_std", torch.ones(1,1,dim))
|
|
self.saturate = saturate
|
|
self.eps = eps
|
|
self.momentum = momentum
|
|
self.dim = dim
|
|
self.trainable_gain = trainable_gain
|
|
if self.trainable_gain:
|
|
self.gain = nn.Parameter(torch.ones(1))
|
|
|
|
def _update_stats(self, x):
|
|
self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)
|
|
self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps)
|
|
|
|
def forward(self, x):
|
|
if self.training:
|
|
self._update_stats(x)
|
|
x = (x - self.running_mean) / self.running_std
|
|
if self.saturate:
|
|
x = torch.asinh(x)
|
|
if self.trainable_gain:
|
|
x = x * self.gain
|
|
return x
|
|
|
|
class LayerNorm(nn.Module):
|
|
def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5):
|
|
"""
|
|
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
|
|
"""
|
|
super().__init__()
|
|
|
|
if fix_scale:
|
|
self.register_buffer("gamma", torch.ones(dim))
|
|
else:
|
|
self.gamma = nn.Parameter(torch.ones(dim))
|
|
|
|
if bias:
|
|
self.beta = nn.Parameter(torch.zeros(dim))
|
|
else:
|
|
self.register_buffer("beta", torch.zeros(dim))
|
|
|
|
self.eps = eps
|
|
|
|
self.force_fp32 = force_fp32
|
|
|
|
def forward(self, x):
|
|
if not self.force_fp32:
|
|
return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps)
|
|
else:
|
|
output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps)
|
|
return output.to(x.dtype)
|
|
|
|
class LayerScale(nn.Module):
|
|
def __init__(self, dim, init_val = 1e-5):
|
|
super().__init__()
|
|
self.scale = nn.Parameter(torch.full([dim], init_val))
|
|
def forward(self, x):
|
|
return x * self.scale
|
|
|
|
class GLU(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim_in,
|
|
dim_out,
|
|
activation: Callable,
|
|
use_conv = False,
|
|
conv_kernel_size = 3,
|
|
):
|
|
super().__init__()
|
|
self.act = activation
|
|
self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
|
|
self.use_conv = use_conv
|
|
|
|
def forward(self, x):
|
|
if self.use_conv:
|
|
x = rearrange(x, 'b n d -> b d n')
|
|
x = self.proj(x)
|
|
x = rearrange(x, 'b d n -> b n d')
|
|
else:
|
|
x = self.proj(x)
|
|
|
|
x, gate = x.chunk(2, dim = -1)
|
|
return x * self.act(gate)
|
|
|
|
class FeedForward(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
dim_out = None,
|
|
mult = 4,
|
|
no_bias = False,
|
|
glu = True,
|
|
use_conv = False,
|
|
conv_kernel_size = 3,
|
|
zero_init_output = True,
|
|
):
|
|
super().__init__()
|
|
inner_dim = int(dim * mult)
|
|
|
|
# Default to SwiGLU
|
|
|
|
activation = nn.SiLU()
|
|
|
|
dim_out = dim if dim_out is None else dim_out
|
|
|
|
if glu:
|
|
linear_in = GLU(dim, inner_dim, activation)
|
|
else:
|
|
linear_in = nn.Sequential(
|
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
|
nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
|
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
|
activation
|
|
)
|
|
|
|
linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
|
|
|
|
# init last linear layer to 0
|
|
if zero_init_output:
|
|
nn.init.zeros_(linear_out.weight)
|
|
if not no_bias:
|
|
nn.init.zeros_(linear_out.bias)
|
|
|
|
|
|
self.ff = nn.Sequential(
|
|
linear_in,
|
|
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
|
linear_out,
|
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.ff(x)
|
|
|
|
class Attention(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
dim_heads = 64,
|
|
dim_context = None,
|
|
causal = False,
|
|
zero_init_output=True,
|
|
qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none',
|
|
differential = False,
|
|
feat_scale = False
|
|
):
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.dim_heads = dim_heads
|
|
self.differential = differential
|
|
|
|
dim_kv = dim_context if dim_context is not None else dim
|
|
|
|
self.num_heads = dim // dim_heads
|
|
self.kv_heads = dim_kv // dim_heads
|
|
|
|
if dim_context is not None:
|
|
if differential:
|
|
self.to_q = nn.Linear(dim, dim * 2, bias=False)
|
|
self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False)
|
|
else:
|
|
self.to_q = nn.Linear(dim, dim, bias=False)
|
|
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
|
|
else:
|
|
if differential:
|
|
self.to_qkv = nn.Linear(dim, dim * 5, bias=False)
|
|
else:
|
|
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
|
|
|
self.to_out = nn.Linear(dim, dim, bias=False)
|
|
|
|
if zero_init_output:
|
|
nn.init.zeros_(self.to_out.weight)
|
|
|
|
if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']:
|
|
raise ValueError(f'qk_norm must be one of ["l2", "ln", "none"], got {qk_norm}')
|
|
|
|
self.qk_norm = qk_norm
|
|
if self.qk_norm == "ln":
|
|
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
|
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
|
elif self.qk_norm == 'rns':
|
|
self.q_norm = nn.RMSNorm(dim_heads)
|
|
self.k_norm = nn.RMSNorm(dim_heads)
|
|
elif self.qk_norm == 'dyt':
|
|
self.q_norm = DynamicTanh(dim_heads)
|
|
self.k_norm = DynamicTanh(dim_heads)
|
|
|
|
self.sdp_kwargs = dict(
|
|
enable_flash = True,
|
|
enable_math = True,
|
|
enable_mem_efficient = True
|
|
)
|
|
|
|
self.feat_scale = feat_scale
|
|
|
|
if self.feat_scale:
|
|
self.lambda_dc = nn.Parameter(torch.zeros(dim))
|
|
self.lambda_hf = nn.Parameter(torch.zeros(dim))
|
|
|
|
self.causal = causal
|
|
if causal:
|
|
print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.')
|
|
|
|
@compile
|
|
def apply_qk_layernorm(self, q, k):
|
|
q_type = q.dtype
|
|
k_type = k.dtype
|
|
q = self.q_norm(q).to(q_type)
|
|
k = self.k_norm(k).to(k_type)
|
|
return q, k
|
|
|
|
|
|
def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None):
|
|
|
|
if self.num_heads != self.kv_heads:
|
|
# Repeat interleave kv_heads to match q_heads for grouped query attention
|
|
heads_per_kv_head = self.num_heads // self.kv_heads
|
|
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
|
|
|
flash_attn_available = HAS_FLASH_ATTN
|
|
if flash_attn_sliding_window is not None and (not flash_attn_available):
|
|
print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available")
|
|
|
|
if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None:
|
|
print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention")
|
|
|
|
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
|
|
print(f"Disabling FlexAttention because causal is set")
|
|
flex_attention_block_mask = None
|
|
flex_attention_score_mod = None
|
|
|
|
if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
|
|
raise NotImplementedError(
|
|
"FlexAttention is not available in this build. "
|
|
"flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
|
|
)
|
|
elif flash_attn_available:
|
|
fa_dtype_in = q.dtype
|
|
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
|
|
|
|
if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16:
|
|
q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v))
|
|
|
|
out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1])
|
|
|
|
out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
|
|
else:
|
|
out = F.scaled_dot_product_attention(q, k, v, is_causal = causal)
|
|
return out
|
|
|
|
|
|
#@compile
|
|
def forward(
|
|
self,
|
|
x,
|
|
context = None,
|
|
rotary_pos_emb = None,
|
|
causal = None,
|
|
flex_attention_block_mask = None,
|
|
flex_attention_score_mod = None,
|
|
flash_attn_sliding_window = None
|
|
):
|
|
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
|
|
|
kv_input = context if has_context else x
|
|
|
|
if hasattr(self, 'to_q'):
|
|
# Use separate linear projections for q and k/v
|
|
if self.differential:
|
|
q, q_diff = self.to_q(x).chunk(2, dim=-1)
|
|
q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff))
|
|
q = torch.stack([q, q_diff], dim = 1)
|
|
k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1)
|
|
k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v))
|
|
k = torch.stack([k, k_diff], dim = 1)
|
|
else:
|
|
q = self.to_q(x)
|
|
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
|
|
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
|
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
|
|
else:
|
|
# Use fused linear projection
|
|
if self.differential:
|
|
q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1)
|
|
q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff))
|
|
q = torch.stack([q, q_diff], dim = 1)
|
|
k = torch.stack([k, k_diff], dim = 1)
|
|
else:
|
|
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
|
|
|
# Normalize q and k for cosine sim attention
|
|
if self.qk_norm == "l2":
|
|
q = F.normalize(q, dim=-1)
|
|
k = F.normalize(k, dim=-1)
|
|
elif self.qk_norm != "none":
|
|
q, k = self.apply_qk_layernorm(q, k)
|
|
|
|
if rotary_pos_emb is not None:
|
|
freqs, _ = rotary_pos_emb
|
|
q_dtype = q.dtype
|
|
k_dtype = k.dtype
|
|
q = q.to(torch.float32)
|
|
k = k.to(torch.float32)
|
|
freqs = freqs.to(torch.float32)
|
|
if q.shape[-2] >= k.shape[-2]:
|
|
ratio = q.shape[-2] / k.shape[-2]
|
|
q_freqs, k_freqs = freqs, ratio * freqs
|
|
else:
|
|
ratio = k.shape[-2] / q.shape[-2]
|
|
q_freqs, k_freqs = ratio * freqs, freqs
|
|
q = apply_rotary_pos_emb(q, q_freqs)
|
|
k = apply_rotary_pos_emb(k, k_freqs)
|
|
q = q.to(v.dtype)
|
|
k = k.to(v.dtype)
|
|
|
|
n, device = q.shape[-2], q.device
|
|
|
|
causal = self.causal if causal is None else causal
|
|
|
|
if n == 1 and causal:
|
|
causal = False
|
|
|
|
if self.differential:
|
|
q, q_diff = q.unbind(dim = 1)
|
|
k, k_diff = k.unbind(dim = 1)
|
|
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
|
out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
|
out = out - out_diff
|
|
else:
|
|
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
|
|
|
# merge heads
|
|
out = rearrange(out, ' b h n d -> b n (h d)')
|
|
|
|
# Communicate between heads
|
|
|
|
# with autocast(enabled = False):
|
|
# out_dtype = out.dtype
|
|
# out = out.to(torch.float32)
|
|
# out = self.to_out(out).to(out_dtype)
|
|
out = self.to_out(out)
|
|
|
|
if self.feat_scale:
|
|
out_dc = out.mean(dim=-2, keepdim=True)
|
|
out_hf = out - out_dc
|
|
|
|
# Selectively modulate DC and high frequency components
|
|
out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf
|
|
|
|
return out
|
|
|
|
class ConformerModule(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
norm_kwargs = {},
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
self.dim = dim
|
|
|
|
self.in_norm = LayerNorm(dim, **norm_kwargs)
|
|
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
|
self.glu = GLU(dim, dim, nn.SiLU())
|
|
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
|
|
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
|
|
self.swish = nn.SiLU()
|
|
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
|
|
|
#@compile
|
|
def forward(self, x):
|
|
x = self.in_norm(x)
|
|
x = rearrange(x, 'b n d -> b d n')
|
|
x = self.pointwise_conv(x)
|
|
x = rearrange(x, 'b d n -> b n d')
|
|
x = self.glu(x)
|
|
x = rearrange(x, 'b n d -> b d n')
|
|
x = self.depthwise_conv(x)
|
|
x = rearrange(x, 'b d n -> b n d')
|
|
x = self.mid_norm(x)
|
|
x = self.swish(x)
|
|
x = rearrange(x, 'b n d -> b d n')
|
|
x = self.pointwise_conv_2(x)
|
|
x = rearrange(x, 'b d n -> b n d')
|
|
|
|
return x
|
|
|
|
class TransformerBlock(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
dim_heads = 64,
|
|
cross_attend = False,
|
|
dim_context = None,
|
|
global_cond_dim = None,
|
|
causal = False,
|
|
zero_init_branch_outputs = True,
|
|
conformer = False,
|
|
layer_ix = -1,
|
|
remove_norms = False,
|
|
add_rope = False,
|
|
layer_scale = False,
|
|
use_sync_block_film = False,
|
|
attn_kwargs = {},
|
|
ff_kwargs = {},
|
|
norm_kwargs = {}
|
|
):
|
|
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.dim_heads = min(dim_heads,dim)
|
|
self.cross_attend = cross_attend
|
|
self.dim_context = dim_context
|
|
self.causal = causal
|
|
if layer_scale and zero_init_branch_outputs:
|
|
print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False')
|
|
zero_init_branch_outputs = False
|
|
|
|
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
|
|
|
self.add_rope = add_rope
|
|
|
|
self.self_attn = Attention(
|
|
dim,
|
|
dim_heads = self.dim_heads,
|
|
causal = causal,
|
|
zero_init_output=zero_init_branch_outputs,
|
|
**attn_kwargs
|
|
)
|
|
|
|
self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
|
|
|
self.cross_attend = cross_attend
|
|
if cross_attend:
|
|
self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
|
self.cross_attn = Attention(
|
|
dim,
|
|
dim_heads = self.dim_heads,
|
|
dim_context=dim_context,
|
|
causal = causal,
|
|
zero_init_output=zero_init_branch_outputs,
|
|
**attn_kwargs
|
|
)
|
|
self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
|
|
|
self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
|
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
|
|
self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
|
|
|
self.layer_ix = layer_ix
|
|
|
|
self.conformer = None
|
|
if conformer:
|
|
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
|
|
self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
|
|
|
self.global_cond_dim = global_cond_dim
|
|
if global_cond_dim is not None:
|
|
self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5)
|
|
|
|
self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None
|
|
|
|
if use_sync_block_film:
|
|
self.sync_film_generator = nn.Sequential(
|
|
nn.Linear(dim, dim, bias=False),
|
|
nn.SiLU(),
|
|
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
|
)
|
|
|
|
@compile
|
|
def forward(
|
|
self,
|
|
x,
|
|
context = None,
|
|
global_cond=None,
|
|
rotary_pos_emb = None,
|
|
self_attention_block_mask = None,
|
|
self_attention_score_mod = None,
|
|
cross_attention_block_mask = None,
|
|
cross_attention_score_mod = None,
|
|
self_attention_flash_sliding_window = None,
|
|
cross_attention_flash_sliding_window = None,
|
|
sync_cond = None,
|
|
prepend_length=0
|
|
):
|
|
if rotary_pos_emb is None and self.add_rope:
|
|
rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2])
|
|
|
|
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
|
if len(global_cond.shape) == 2:
|
|
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1)
|
|
else:
|
|
|
|
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1)
|
|
|
|
# self-attention with adaLN
|
|
residual = x
|
|
x = self.pre_norm(x)
|
|
x = x * (1 + scale_self) + shift_self
|
|
x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)
|
|
x = x * torch.sigmoid(1 - gate_self)
|
|
x = self.self_attn_scale(x)
|
|
x = x + residual
|
|
|
|
if context is not None and self.cross_attend:
|
|
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
|
|
|
|
if self.conformer is not None:
|
|
x = x + self.conformer_scale(self.conformer(x))
|
|
|
|
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
|
|
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
|
x = x * (1 + scale) + shift
|
|
|
|
# feedforward with adaLN
|
|
residual = x
|
|
x = self.ff_norm(x)
|
|
x = x * (1 + scale_ff) + shift_ff
|
|
x = self.ff(x)
|
|
x = x * torch.sigmoid(1 - gate_ff)
|
|
x = self.ff_scale(x)
|
|
x = x + residual
|
|
|
|
else:
|
|
x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window))
|
|
|
|
if context is not None and self.cross_attend:
|
|
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
|
|
|
|
if self.conformer is not None:
|
|
x = x + self.conformer_scale(self.conformer(x))
|
|
|
|
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
|
|
prepend_part = x[:, :prepend_length, :]
|
|
audio_part = x[:, prepend_length:, :]
|
|
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
|
modulated_audio_part = audio_part * (1 + scale) + shift
|
|
x = torch.cat([prepend_part, modulated_audio_part], dim=1)
|
|
|
|
x = x + self.ff_scale(self.ff(self.ff_norm(x)))
|
|
return x
|
|
|
|
class ContinuousTransformer(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim,
|
|
depth,
|
|
*,
|
|
dim_in = None,
|
|
dim_out = None,
|
|
dim_heads = 64,
|
|
cross_attend=False,
|
|
cond_token_dim=None,
|
|
pre_cross_attn_ix=-1,
|
|
final_cross_attn_ix=-1,
|
|
global_cond_dim=None,
|
|
causal=False,
|
|
rotary_pos_emb=True,
|
|
zero_init_branch_outputs=True,
|
|
conformer=False,
|
|
use_sinusoidal_emb=False,
|
|
use_abs_pos_emb=False,
|
|
abs_pos_emb_max_length=10000,
|
|
num_memory_tokens=0,
|
|
sliding_window=None,
|
|
use_mlp=False,
|
|
use_add_norm=False,
|
|
use_gated=False,
|
|
use_final_layer=False,
|
|
use_zeros=False,
|
|
use_conv=False,
|
|
use_fusion_mlp=False,
|
|
use_film=False,
|
|
use_sync_film=False,
|
|
use_sync_gated=False,
|
|
**kwargs
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
self.dim = dim
|
|
self.depth = depth
|
|
self.causal = causal
|
|
self.layers = nn.ModuleList([])
|
|
if use_mlp:
|
|
self.project_in = nn.Sequential(
|
|
nn.Linear(dim_in, dim, bias=False),
|
|
nn.SiLU(),
|
|
nn.Linear(dim, dim, bias=False)
|
|
)
|
|
else:
|
|
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
|
|
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
|
|
self.video_temporal_conv = None
|
|
self.audio_temporal_conv = None
|
|
self.fusion_mlp = None
|
|
if use_conv:
|
|
self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
|
|
self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
|
|
if use_fusion_mlp:
|
|
self.fusion_mlp = nn.Sequential(
|
|
nn.Linear(dim, dim),
|
|
nn.SiLU(),
|
|
nn.Linear(dim, dim)
|
|
)
|
|
|
|
if rotary_pos_emb:
|
|
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
|
else:
|
|
self.rotary_pos_emb = None
|
|
self.num_memory_tokens = num_memory_tokens
|
|
if num_memory_tokens > 0:
|
|
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
|
|
|
self.use_sinusoidal_emb = use_sinusoidal_emb
|
|
if use_sinusoidal_emb:
|
|
self.pos_emb = ScaledSinusoidalEmbedding(dim)
|
|
|
|
self.use_abs_pos_emb = use_abs_pos_emb
|
|
if use_abs_pos_emb:
|
|
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens)
|
|
|
|
self.adaLN_modulation = None
|
|
if global_cond_dim is not None:
|
|
if use_final_layer:
|
|
self.norm_final = LayerNorm(dim)
|
|
self.adaLN_modulation = nn.Sequential(
|
|
nn.SiLU(),
|
|
nn.Linear(
|
|
dim, 2 * dim, bias=True
|
|
),
|
|
)
|
|
|
|
if use_zeros:
|
|
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
|
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
|
nn.init.constant_(self.project_out.weight, 0)
|
|
self.global_cond_embedder = nn.Sequential(
|
|
nn.Linear(global_cond_dim, dim),
|
|
nn.SiLU(),
|
|
nn.Linear(dim, dim * 6)
|
|
)
|
|
if use_zeros:
|
|
nn.init.constant_(self.global_cond_embedder[-1].weight, 0)
|
|
nn.init.constant_(self.global_cond_embedder[-1].bias, 0)
|
|
nn.init.constant_(self.global_cond_embedder[0].weight, 0)
|
|
nn.init.constant_(self.global_cond_embedder[0].bias, 0)
|
|
|
|
self.final_cross_attn_ix = final_cross_attn_ix
|
|
self.use_gated = use_gated
|
|
self.use_film = use_film
|
|
self.use_add_norm = use_add_norm
|
|
if self.use_add_norm:
|
|
self.add_norm = nn.LayerNorm(dim)
|
|
if use_gated:
|
|
self.gate = nn.Parameter(torch.ones(1, 1, dim))
|
|
|
|
if use_film:
|
|
self.film_generator = nn.Sequential(
|
|
nn.Linear(dim, dim, bias=False),
|
|
nn.SiLU(),
|
|
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
|
)
|
|
else:
|
|
self.film_generator = None
|
|
|
|
if use_sync_film:
|
|
self.sync_film_generator = nn.Sequential(
|
|
nn.Linear(dim, dim, bias=False),
|
|
nn.SiLU(),
|
|
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
|
)
|
|
else:
|
|
self.sync_film_generator = None
|
|
if use_sync_gated:
|
|
self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim))
|
|
else:
|
|
self.sync_gate = None
|
|
|
|
self.sliding_window = sliding_window
|
|
|
|
for i in range(depth):
|
|
should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix))
|
|
# print(f"Layer {i} cross attends: {should_cross_attend}")
|
|
self.layers.append(
|
|
TransformerBlock(
|
|
dim,
|
|
dim_heads = dim_heads,
|
|
cross_attend = should_cross_attend,
|
|
dim_context = cond_token_dim,
|
|
global_cond_dim = global_cond_dim,
|
|
causal = causal,
|
|
zero_init_branch_outputs = zero_init_branch_outputs,
|
|
conformer=conformer,
|
|
layer_ix=i,
|
|
**kwargs
|
|
)
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
x,
|
|
mask = None,
|
|
prepend_embeds = None,
|
|
prepend_mask = None,
|
|
add_cond = None,
|
|
sync_cond = None,
|
|
global_cond = None,
|
|
return_info = False,
|
|
use_checkpointing = True,
|
|
exit_layer_ix = None,
|
|
video_dropout_prob = 0.0,
|
|
**kwargs
|
|
):
|
|
batch, seq, device = *x.shape[:2], x.device
|
|
model_dtype = next(self.parameters()).dtype
|
|
x = x.to(model_dtype)
|
|
|
|
info = {
|
|
"hidden_states": [],
|
|
}
|
|
|
|
x = self.project_in(x)
|
|
if add_cond is not None:
|
|
if self.use_gated:
|
|
gate = torch.sigmoid(self.gate)
|
|
x = x + gate * add_cond
|
|
elif self.use_film:
|
|
scale, shift = self.film_generator(add_cond).chunk(2, dim=-1)
|
|
x = x * (1 + scale) + shift
|
|
else:
|
|
x = x + add_cond
|
|
|
|
if self.use_add_norm:
|
|
x = self.add_norm(x)
|
|
if self.fusion_mlp is not None:
|
|
x = self.fusion_mlp(x)
|
|
|
|
if sync_cond is not None:
|
|
if self.sync_film_generator is not None:
|
|
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
|
x = x * (1 + scale) + shift
|
|
elif self.sync_gate is not None:
|
|
gate_value = torch.sigmoid(self.sync_gate)
|
|
x = x + gate_value * sync_cond
|
|
# else:
|
|
# x = x + sync_cond
|
|
|
|
if prepend_embeds is not None:
|
|
prepend_length, prepend_dim = prepend_embeds.shape[1:]
|
|
|
|
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
|
|
|
|
x = torch.cat((prepend_embeds, x), dim = -2)
|
|
|
|
if self.num_memory_tokens > 0:
|
|
memory_tokens = self.memory_tokens.expand(batch, -1, -1)
|
|
x = torch.cat((memory_tokens, x), dim=1)
|
|
|
|
if self.rotary_pos_emb is not None:
|
|
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
|
|
else:
|
|
rotary_pos_emb = None
|
|
|
|
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
|
x = x + self.pos_emb(x)
|
|
|
|
if global_cond is not None and self.global_cond_embedder is not None:
|
|
global_cond_embed = self.global_cond_embedder(global_cond)
|
|
else:
|
|
global_cond_embed = global_cond
|
|
# Iterate over the transformer layers
|
|
for layer_ix, layer in enumerate(self.layers):
|
|
if use_checkpointing:
|
|
x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
|
|
else:
|
|
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
|
|
|
|
if return_info:
|
|
info["hidden_states"].append(x)
|
|
|
|
if exit_layer_ix is not None and layer_ix == exit_layer_ix:
|
|
x = x[:, self.num_memory_tokens:, :]
|
|
|
|
if return_info:
|
|
return x, info
|
|
|
|
return x
|
|
|
|
x = x[:, self.num_memory_tokens:, :]
|
|
if global_cond is not None and self.adaLN_modulation is not None:
|
|
if len(global_cond.shape) == 2:
|
|
global_cond = global_cond.unsqueeze(1)
|
|
shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1)
|
|
x = modulate(self.norm_final(x), shift, scale)
|
|
x = self.project_out(x)
|
|
|
|
if return_info:
|
|
return x, info
|
|
|
|
return x
|