from functools import reduce, partial
from packaging import version

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.cuda.amp import autocast
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from typing import Callable, Literal

try:
    from flash_attn import flash_attn_func, flash_attn_kvpacked_func
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
    flash_attn_kvpacked_func = None
    flash_attn_func = None

from .utils import compile, checkpoint

try:
    import natten
except ImportError:
    natten = None


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Apply an affine modulation: ``x * (1 + scale) + shift``."""
    return x * (1 + scale) + shift


# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt

def create_causal_mask(i, j, device):
    """Return an ``(i, j)`` boolean mask that is True strictly above diagonal ``j - i``."""
    return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)


def or_reduce(masks):
    """Combine a non-empty sequence of masks with elementwise ``|``."""
    head, *body = masks
    for rest in body:
        head = head | rest
    return head


# positional embeddings

class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding, scaled by ``dim ** -0.5``."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return the embedding for positions ``pos`` (default ``arange(x.shape[1])``).

        When ``seq_start_pos`` is given it is subtracted from the positions and
        the result is clamped at zero.
        """
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        return pos_emb


class ScaledSinusoidalEmbedding(nn.Module):
    """Fixed sin/cos positional embedding multiplied by a learned scalar."""

    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        freq_seq = torch.arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        self.register_buffer('inv_freq', inv_freq, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return ``cat(sin, cos)`` of ``pos * inv_freq``, times ``self.scale``."""
        seq_len, device = x.shape[1], x.device

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        emb = einsum('i, j -> i j', pos, self.inv_freq)
        emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
        return emb * self.scale


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding frequencies, with optional xpos scaling."""

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len):
        """Compute frequencies for positions ``0 .. seq_len - 1``."""
        device = self.inv_freq.device

        t = torch.arange(seq_len, device = device)
        return self.forward(t)

    @autocast(enabled = False)
    def forward(self, t):
        """Return ``(freqs, scale)`` for the 1-D position tensor ``t``.

        ``scale`` is the float ``1.`` when xpos is disabled, otherwise a tensor
        with one row per position.
        """
        device = self.inv_freq.device

        t = t.to(torch.float32)
        # The einsum below treats t as 1-D, so its length is the sequence length.
        # Previously the xpos branch referenced an undefined `seq_len`.
        seq_len = t.shape[0]

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale


def rotate_half(x):
    """Split the last dim into two halves ``(x1, x2)`` and return ``cat(-x2, x1)``."""
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)


@autocast(enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Rotate the first ``freqs.shape[-1]`` features of ``t``; pass the rest through.

    Computation runs in at least float32 and the result is cast back to the
    input dtype of ``t``.
    """
    out_dtype = t.dtype

    # cast to float32 if necessary for numerical stability
    dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
    freqs, t = freqs.to(dtype), t.to(dtype)
    # keep only the trailing seq_len rows of the frequency table
    freqs = freqs[-seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

    t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)

    return torch.cat((t, t_unrotated), dim = -1)


# norms

class DynamicTanh(nn.Module):
    """Elementwise ``gamma * tanh(alpha * x) + beta`` with learned parameters."""

    def __init__(self, dim, init_alpha=10.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        x = F.tanh(self.alpha * x)
        return self.gamma * x + self.beta


class RunningInstanceNorm(nn.Module):
    """Normalize with running mean/std buffers updated by EMA during training."""

    def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(1,1,dim))
        self.register_buffer("running_std", torch.ones(1,1,dim))
        self.saturate = saturate
        self.eps = eps
        self.momentum = momentum
        self.dim = dim
        self.trainable_gain = trainable_gain
        if self.trainable_gain:
            self.gain = nn.Parameter(torch.ones(1))

    def _update_stats(self, x):
        """EMA-update the running statistics from ``x`` reduced over dims 0 and 1."""
        self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)
        # the std is clipped from below so the division in forward stays finite
        self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps)

    def forward(self, x):
        if self.training:
            self._update_stats(x)

        x = (x - self.running_mean) / self.running_std

        if self.saturate:
            x = torch.asinh(x)

        if self.trainable_gain:
            x = x * self.gain

        return x


class LayerNorm(nn.Module):
    def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
        """
        super().__init__()

        # gamma/beta are registered as buffers (not trained) when fixed/disabled
        if fix_scale:
            self.register_buffer("gamma", torch.ones(dim))
        else:
            self.gamma = nn.Parameter(torch.ones(dim))

        if bias:
            self.beta = nn.Parameter(torch.zeros(dim))
        else:
            self.register_buffer("beta", torch.zeros(dim))

        self.eps = eps
        self.force_fp32 = force_fp32

    def forward(self, x):
        """Layer-normalize over the last dim; with ``force_fp32`` compute in float32 and cast back."""
        if not self.force_fp32:
            return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps)
        else:
            output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps)
            return output.to(x.dtype)


class LayerScale(nn.Module):
    """Multiply the input by a learned per-feature scale initialised to ``init_val``."""

    def __init__(self, dim, init_val = 1e-5):
        super().__init__()
        self.scale = nn.Parameter(torch.full([dim], init_val))

    def forward(self, x):
        return x * self.scale


class GLU(nn.Module):
    """Gated linear unit: project to ``2 * dim_out``, then ``value * act(gate)``."""

    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        use_conv = False,
        conv_kernel_size = 3,
    ):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
        self.use_conv = use_conv

    def forward(self, x):
        if self.use_conv:
            # Conv1d runs over the last axis, so swap the last two dims around it
            x = rearrange(x, 'b n d -> b d n')
            x = self.proj(x)
            x = rearrange(x, 'b d n -> b n d')
        else:
            x = self.proj(x)

        x, gate = x.chunk(2, dim = -1)
        return x * self.act(gate)


class FeedForward(nn.Module):
    """Feed-forward block: (GLU or linear/conv + SiLU) followed by an output projection."""

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
    ):
        super().__init__()
        inner_dim = int(dim * mult)

        # Default to SwiGLU

        activation = nn.SiLU()

        dim_out = dim if dim_out is None else dim_out

        if glu:
            # NOTE: use_conv is not forwarded to the GLU, so it always uses nn.Linear here
            linear_in = GLU(dim, inner_dim, activation)
        else:
            # each Rearrange below just swaps the last two dims (axis names are only labels)
            linear_in = nn.Sequential(
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                activation
            )

        linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)

        # init last linear layer to 0
        if zero_init_output:
            nn.init.zeros_(linear_out.weight)
            if not no_bias:
                nn.init.zeros_(linear_out.bias)

        self.ff = nn.Sequential(
            linear_in,
            Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
            linear_out,
            Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
        )

    def forward(self, x):
        return self.ff(x)


class Attention(nn.Module):
    """Multi-head self/cross attention with optional q/k norm and differential attention.

    With ``dim_context`` set, separate ``to_q`` / ``to_kv`` projections are
    created; otherwise a fused ``to_qkv`` projection is used.
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        dim_context = None,
        causal = False,
        zero_init_output=True,
        qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none',
        differential = False,
        feat_scale = False
    ):
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.differential = differential

        dim_kv = dim_context if dim_context is not None else dim

        self.num_heads = dim // dim_heads
        self.kv_heads = dim_kv // dim_heads

        if dim_context is not None:
            if differential:
                # q + q_diff, and k + k_diff + v
                self.to_q = nn.Linear(dim, dim * 2, bias=False)
                self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False)
            else:
                self.to_q = nn.Linear(dim, dim, bias=False)
                self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
        else:
            if differential:
                # q, k, v, q_diff, k_diff
                self.to_qkv = nn.Linear(dim, dim * 5, bias=False)
            else:
                self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        self.to_out = nn.Linear(dim, dim, bias=False)

        if zero_init_output:
            nn.init.zeros_(self.to_out.weight)

        if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']:
            # message now lists every accepted value (it used to omit "rns" and "dyt")
            raise ValueError(f'qk_norm must be one of ["l2", "ln", "rns", "dyt", "none"], got {qk_norm}')

        self.qk_norm = qk_norm

        if self.qk_norm == "ln":
            self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
            self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
        elif self.qk_norm == 'rns':
            self.q_norm = nn.RMSNorm(dim_heads)
            self.k_norm = nn.RMSNorm(dim_heads)
        elif self.qk_norm == 'dyt':
            self.q_norm = DynamicTanh(dim_heads)
            self.k_norm = DynamicTanh(dim_heads)

        self.sdp_kwargs = dict(
            enable_flash = True,
            enable_math = True,
            enable_mem_efficient = True
        )

        self.feat_scale = feat_scale

        if self.feat_scale:
            self.lambda_dc = nn.Parameter(torch.zeros(dim))
            self.lambda_hf = nn.Parameter(torch.zeros(dim))

        self.causal = causal

    @compile
    def apply_qk_layernorm(self, q, k):
        """Apply ``q_norm`` / ``k_norm`` and cast each result back to its input dtype."""
        q_type = q.dtype
        k_type = k.dtype
        q = self.q_norm(q).to(q_type)
        k = self.k_norm(k).to(k_type)
        return q, k

    def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None):
        """Run attention on ``(b, h, n, d)`` tensors via flash-attn or PyTorch SDPA.

        Raises:
            NotImplementedError: if flex-attention arguments are given for
                non-causal attention (they are dropped when ``causal`` is truthy).
        """
        if self.num_heads != self.kv_heads:
            # Repeat interleave kv_heads to match q_heads for grouped query attention
            heads_per_kv_head = self.num_heads // self.kv_heads
            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

        # flash-attn only provides CUDA kernels, so fall back to SDPA for
        # tensors on other devices even when the package is importable.
        flash_attn_available = HAS_FLASH_ATTN and q.is_cuda

        if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
            flex_attention_block_mask = None
            flex_attention_score_mod = None

        if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
            raise NotImplementedError(
                "FlexAttention is not available in this build. "
                "flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
            )
        elif flash_attn_available:
            fa_dtype_in = q.dtype
            # flash_attn_func is called with (b, n, h, d) layout
            q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))

            # anything that is not fp16/bf16 is cast to bf16 for the kernel
            if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16:
                q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v))

            out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1])

            out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
        else:
            # NOTE: flash_attn_sliding_window is not applied on this path
            out = F.scaled_dot_product_attention(q, k, v, is_causal = causal)

        return out

    #@compile
    def forward(
        self,
        x,
        context = None,
        rotary_pos_emb = None,
        causal = None,
        flex_attention_block_mask = None,
        flex_attention_score_mod = None,
        flash_attn_sliding_window = None
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when ``context`` is None).

        ``rotary_pos_emb`` is a ``(freqs, scale)`` pair; only ``freqs`` is used.
        ``causal`` overrides ``self.causal`` when not None.
        """
        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

        kv_input = context if has_context else x

        if hasattr(self, 'to_q'):
            # Use separate linear projections for q and k/v
            if self.differential:
                q, q_diff = self.to_q(x).chunk(2, dim=-1)
                q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff))
                q = torch.stack([q, q_diff], dim = 1)
                k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1)
                k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v))
                k = torch.stack([k, k_diff], dim = 1)
            else:
                q = self.to_q(x)
                q = rearrange(q, 'b n (h d) -> b h n d', h = h)
                k, v = self.to_kv(kv_input).chunk(2, dim=-1)
                k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
        else:
            # Use fused linear projection
            if self.differential:
                q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1)
                q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff))
                q = torch.stack([q, q_diff], dim = 1)
                k = torch.stack([k, k_diff], dim = 1)
            else:
                q, k, v = self.to_qkv(x).chunk(3, dim=-1)
                q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Normalize q and k for cosine sim attention
        if self.qk_norm == "l2":
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)
        elif self.qk_norm != "none":
            q, k = self.apply_qk_layernorm(q, k)

        if rotary_pos_emb is not None:
            freqs, _ = rotary_pos_emb

            q = q.to(torch.float32)
            k = k.to(torch.float32)
            freqs = freqs.to(torch.float32)

            # when q and k differ in length, the shorter one's frequencies are
            # multiplied by the length ratio
            if q.shape[-2] >= k.shape[-2]:
                ratio = q.shape[-2] / k.shape[-2]
                q_freqs, k_freqs = freqs, ratio * freqs
            else:
                ratio = k.shape[-2] / q.shape[-2]
                q_freqs, k_freqs = ratio * freqs, freqs

            q = apply_rotary_pos_emb(q, q_freqs)
            k = apply_rotary_pos_emb(k, k_freqs)

            q = q.to(v.dtype)
            k = k.to(v.dtype)

        n = q.shape[-2]

        causal = self.causal if causal is None else causal

        # a single query position has nothing to mask
        if n == 1 and causal:
            causal = False

        if self.differential:
            q, q_diff = q.unbind(dim = 1)
            k, k_diff = k.unbind(dim = 1)
            out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
            out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
            out = out - out_diff
        else:
            out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)

        # merge heads
        out = rearrange(out, ' b h n d -> b n (h d)')

        # Communicate between heads

        # with autocast(enabled = False):
        #     out_dtype = out.dtype
        #     out = out.to(torch.float32)
        #     out = self.to_out(out).to(out_dtype)
        out = self.to_out(out)

        if self.feat_scale:
            out_dc = out.mean(dim=-2, keepdim=True)
            out_hf = out - out_dc

            # Selectively modulate DC and high frequency components
            out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf

        return out


class ConformerModule(nn.Module):
    """Convolution module: norm, pointwise conv, GLU, depthwise conv, norm, SiLU, pointwise conv."""

    def __init__(
        self,
        dim,
        norm_kwargs = None,
    ):
        super().__init__()

        # None replaces the former shared mutable `{}` default
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        self.glu = GLU(dim, dim, nn.SiLU())
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    #@compile
    def forward(self, x):
        # the rearranges swap the last two dims around each Conv1d
        x = self.in_norm(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.glu(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.depthwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.mid_norm(x)
        x = self.swish(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv_2(x)
        x = rearrange(x, 'b d n -> b n d')

        return x


class TransformerBlock(nn.Module):
    """Pre-norm block: self-attention, optional cross-attention / conformer / sync FiLM, feed-forward.

    When ``global_cond_dim`` is set and a ``global_cond`` is passed to
    ``forward``, the self-attention and feed-forward branches are modulated
    with scale/shift/gate values derived from it.
    """

    def __init__(
            self,
            dim,
            dim_heads = 64,
            cross_attend = False,
            dim_context = None,
            global_cond_dim = None,
            causal = False,
            zero_init_branch_outputs = True,
            conformer = False,
            layer_ix = -1,
            remove_norms = False,
            add_rope = False,
            layer_scale = False,
            use_sync_block_film = False,
            attn_kwargs = None,
            ff_kwargs = None,
            norm_kwargs = None
    ):

        super().__init__()

        # None replaces the former shared mutable `{}` defaults
        attn_kwargs = {} if attn_kwargs is None else attn_kwargs
        ff_kwargs = {} if ff_kwargs is None else ff_kwargs
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs

        self.dim = dim
        self.dim_heads = min(dim_heads,dim)
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        # zero-init of branch outputs is switched off whenever layer scale is used
        if layer_scale and zero_init_branch_outputs:
            zero_init_branch_outputs = False

        self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)

        self.add_rope = add_rope

        self.self_attn = Attention(
            dim,
            dim_heads = self.dim_heads,
            causal = causal,
            zero_init_output=zero_init_branch_outputs,
            **attn_kwargs
        )

        self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()

        self.cross_attend = cross_attend
        if cross_attend:
            self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
            self.cross_attn = Attention(
                dim,
                dim_heads = self.dim_heads,
                dim_context=dim_context,
                causal = causal,
                zero_init_output=zero_init_branch_outputs,
                **attn_kwargs
            )
            self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()

        self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
        self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity()

        self.layer_ix = layer_ix

        self.conformer = None
        if conformer:
            self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
            self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity()

        self.global_cond_dim = global_cond_dim

        if global_cond_dim is not None:
            # learned offset added to the incoming global_cond before it is
            # chunked into six scale/shift/gate tensors
            self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5)

        self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None

        if use_sync_block_film:
            self.sync_film_generator = nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim * 2, bias=False)  # generates scale and shift from sync_cond
            )

    @compile
    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        rotary_pos_emb = None,
        self_attention_block_mask = None,
        self_attention_score_mod = None,
        cross_attention_block_mask = None,
        cross_attention_score_mod = None,
        self_attention_flash_sliding_window = None,
        cross_attention_flash_sliding_window = None,
        sync_cond = None,
        prepend_length=0
    ):
        """Run the block on ``x`` and return a tensor of the same shape.

        ``prepend_length`` is only used on the non-global-cond path, where the
        sync FiLM skips the first ``prepend_length`` positions.
        """
        if rotary_pos_emb is None and self.add_rope:
            rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2])

        if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

            # a 2-D global_cond gets a singleton axis at dim 1 so it broadcasts against x
            if len(global_cond.shape) == 2:
                scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1)
            else:
                scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1)

            # self-attention with adaLN
            residual = x
            x = self.pre_norm(x)
            x = x * (1 + scale_self) + shift_self
            x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)
            x = x * torch.sigmoid(1 - gate_self)
            x = self.self_attn_scale(x)
            x = x + residual

            if context is not None and self.cross_attend:
                x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))

            if self.conformer is not None:
                x = x + self.conformer_scale(self.conformer(x))

            # NOTE(review): unlike the branch below, this FiLM is applied to the
            # whole sequence and prepend_length is ignored -- confirm intended
            if sync_cond is not None and hasattr(self, 'sync_film_generator'):
                scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
                x = x * (1 + scale) + shift

            # feedforward with adaLN
            residual = x
            x = self.ff_norm(x)
            x = x * (1 + scale_ff) + shift_ff
            x = self.ff(x)
            x = x * torch.sigmoid(1 - gate_ff)
            x = self.ff_scale(x)
            x = x + residual

        else:
            x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window))

            if context is not None and self.cross_attend:
                x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))

            if self.conformer is not None:
                x = x + self.conformer_scale(self.conformer(x))

            if sync_cond is not None and hasattr(self, 'sync_film_generator'):
                # only the positions after the prepended ones are modulated
                prepend_part = x[:, :prepend_length, :]
                audio_part = x[:, prepend_length:, :]
                scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
                modulated_audio_part = audio_part * (1 + scale) + shift
                x = torch.cat([prepend_part, modulated_audio_part], dim=1)

            x = x + self.ff_scale(self.ff(self.ff_norm(x)))

        return x


class ContinuousTransformer(nn.Module):
    """Stack of ``TransformerBlock`` layers over continuous (non-token) inputs.

    Extra keyword arguments are forwarded to every ``TransformerBlock``.
    """

    def __init__(
        self,
        dim,
        depth,
        *,
        dim_in = None,
        dim_out = None,
        dim_heads = 64,
        cross_attend=False,
        cond_token_dim=None,
        pre_cross_attn_ix=-1,
        final_cross_attn_ix=-1,
        global_cond_dim=None,
        causal=False,
        rotary_pos_emb=True,
        zero_init_branch_outputs=True,
        conformer=False,
        use_sinusoidal_emb=False,
        use_abs_pos_emb=False,
        abs_pos_emb_max_length=10000,
        num_memory_tokens=0,
        sliding_window=None,
        use_mlp=False,
        use_add_norm=False,
        use_gated=False,
        use_final_layer=False,
        use_zeros=False,
        use_conv=False,
        use_fusion_mlp=False,
        use_film=False,
        use_sync_film=False,
        use_sync_gated=False,
        **kwargs
    ):

        super().__init__()

        self.dim = dim
        self.depth = depth
        self.causal = causal
        self.layers = nn.ModuleList([])

        if use_mlp:
            self.project_in = nn.Sequential(
                nn.Linear(dim_in, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim, bias=False)
            )
        else:
            self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
        self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()

        self.video_temporal_conv = None
        self.audio_temporal_conv = None
        self.fusion_mlp = None

        if use_conv:
            self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
            self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)

        # NOTE(review): assumed to be independent of use_conv -- confirm nesting
        if use_fusion_mlp:
            self.fusion_mlp = nn.Sequential(
                nn.Linear(dim, dim),
                nn.SiLU(),
                nn.Linear(dim, dim)
            )

        if rotary_pos_emb:
            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
        else:
            self.rotary_pos_emb = None

        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

        self.use_sinusoidal_emb = use_sinusoidal_emb
        if use_sinusoidal_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)

        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens)

        self.adaLN_modulation = None
        # Always define the attribute: forward() reads it unconditionally, and an
        # unset attribute on an nn.Module raises AttributeError.
        self.global_cond_embedder = None

        if global_cond_dim is not None:
            if use_final_layer:
                self.norm_final = LayerNorm(dim)
                self.adaLN_modulation = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(
                        dim, 2 * dim, bias=True
                    ),
                )
                if use_zeros:
                    nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
                    nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
                    # project_out is nn.Identity (no weight) when dim_out is None
                    if isinstance(self.project_out, nn.Linear):
                        nn.init.constant_(self.project_out.weight, 0)

            self.global_cond_embedder = nn.Sequential(
                nn.Linear(global_cond_dim, dim),
                nn.SiLU(),
                nn.Linear(dim, dim * 6)
            )
            if use_zeros:
                nn.init.constant_(self.global_cond_embedder[-1].weight, 0)
                nn.init.constant_(self.global_cond_embedder[-1].bias, 0)
                nn.init.constant_(self.global_cond_embedder[0].weight, 0)
                nn.init.constant_(self.global_cond_embedder[0].bias, 0)

        self.final_cross_attn_ix = final_cross_attn_ix

        self.use_gated = use_gated
        self.use_film = use_film
        self.use_add_norm = use_add_norm

        if self.use_add_norm:
            self.add_norm = nn.LayerNorm(dim)

        if use_gated:
            self.gate = nn.Parameter(torch.ones(1, 1, dim))

        if use_film:
            self.film_generator = nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim * 2, bias=False)  # generates scale and shift (fed add_cond in forward)
            )
        else:
            self.film_generator = None

        if use_sync_film:
            self.sync_film_generator = nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim * 2, bias=False)  # generates scale and shift from sync_cond
            )
        else:
            self.sync_film_generator = None

        if use_sync_gated:
            self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim))
        else:
            self.sync_gate = None

        self.sliding_window = sliding_window

        for i in range(depth):
            # cross-attend only in layers pre_cross_attn_ix <= i < final_cross_attn_ix
            # (-1 disables the respective bound)
            should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix))
            # print(f"Layer {i} cross attends: {should_cross_attend}")
            self.layers.append(
                TransformerBlock(
                    dim,
                    dim_heads = dim_heads,
                    cross_attend = should_cross_attend,
                    dim_context = cond_token_dim,
                    global_cond_dim = global_cond_dim,
                    causal = causal,
                    zero_init_branch_outputs = zero_init_branch_outputs,
                    conformer=conformer,
                    layer_ix=i,
                    **kwargs
                )
            )

    def forward(
        self,
        x,
        mask = None,
        prepend_embeds = None,
        prepend_mask = None,
        add_cond = None,
        sync_cond = None,
        global_cond = None,
        return_info = False,
        use_checkpointing = True,
        exit_layer_ix = None,
        video_dropout_prob = 0.0,
        **kwargs
    ):
        """Project ``x`` in, apply conditioning, run the layers and project out.

        ``mask``, ``prepend_mask`` and ``video_dropout_prob`` are accepted but
        not used in this method. Extra ``kwargs`` are forwarded to every layer.

        Returns:
            ``x``, or ``(x, info)`` when ``return_info`` is True, where
            ``info["hidden_states"]`` holds each layer's output. With
            ``exit_layer_ix`` set, the output of that layer is returned early
            (memory tokens stripped, no final modulation or output projection).
        """
        batch = x.shape[0]

        model_dtype = next(self.parameters()).dtype
        x = x.to(model_dtype)
        prepend_length = 0
        info = {
            "hidden_states": [],
        }

        x = self.project_in(x)

        if add_cond is not None:
            if self.use_gated:
                gate = torch.sigmoid(self.gate)
                x = x + gate * add_cond
            elif self.use_film:
                scale, shift = self.film_generator(add_cond).chunk(2, dim=-1)
                x = x * (1 + scale) + shift
            else:
                x = x + add_cond

            # NOTE(review): add_norm / fusion_mlp are assumed to run only when
            # add_cond is given -- confirm nesting against the original layout
            if self.use_add_norm:
                x = self.add_norm(x)
            if self.fusion_mlp is not None:
                x = self.fusion_mlp(x)

        if sync_cond is not None:
            if self.sync_film_generator is not None:
                scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
                x = x * (1 + scale) + shift
            elif self.sync_gate is not None:
                gate_value = torch.sigmoid(self.sync_gate)
                x = x + gate_value * sync_cond
            # else:
            #     x = x + sync_cond

        if prepend_embeds is not None:
            prepend_length, prepend_dim = prepend_embeds.shape[1:]

            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'

            x = torch.cat((prepend_embeds, x), dim = -2)

        if self.num_memory_tokens > 0:
            memory_tokens = self.memory_tokens.expand(batch, -1, -1)
            x = torch.cat((memory_tokens, x), dim=1)

        if self.rotary_pos_emb is not None:
            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
        else:
            rotary_pos_emb = None

        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        if global_cond is not None and self.global_cond_embedder is not None:
            global_cond_embed = self.global_cond_embedder(global_cond)
        else:
            global_cond_embed = global_cond

        # Iterate over the transformer layers
        for layer_ix, layer in enumerate(self.layers):

            if use_checkpointing:
                x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
            else:
                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)

            if return_info:
                info["hidden_states"].append(x)

            if exit_layer_ix is not None and layer_ix == exit_layer_ix:
                # early exit: drop the memory tokens and return this layer's output
                x = x[:, self.num_memory_tokens:, :]

                if return_info:
                    return x, info

                return x

        # memory tokens were prepended last, so they are the leading positions
        x = x[:, self.num_memory_tokens:, :]

        if global_cond is not None and self.adaLN_modulation is not None:
            if len(global_cond.shape) == 2:
                global_cond = global_cond.unsqueeze(1)
            shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1)
            x = modulate(self.norm_final(x), shift, scale)

        x = self.project_out(x)

        if return_info:
            return x, info

        return x