Files
ComfyUI-SelVA/selva_core/model/transformer_layers.py
T
Ethanfel 6bc3fd6443 chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced
librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy
version incompatibility in some ComfyUI environments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 15:18:09 +02:00

427 lines
16 KiB
Python

from typing import Optional
from inspect import isfunction
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from selva_core.ext.rotary_embeddings import apply_rope
from selva_core.model.low_level import MLP, ChannelLastConv1d, ConvMLP
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return x * (1 + scale) + shift
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
# training will crash without these contiguous calls and the CUDNN limitation
# I believe this is related to https://github.com/pytorch/pytorch/issues/133974
# unresolved at the time of writing
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
if attn_mask is not None:
attn_mask = attn_mask.contiguous()
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
# out = rearrange(out, 'b h n d -> b n (h d)').contiguous()
b, h, n, d_head = out.shape
out = out.permute(0, 2, 1, 3) # Shape: (b, n, h, d_head)
# Using reshape, which can handle non-contiguous tensors by copying if necessary
out = out.reshape(b, n, h * d_head) # Shape: (b, n, h * d_head)
# Ensure the final output is contiguous, similar to the original code's intent
out = out.contiguous()
return out
def attention_debug(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
layer_idx: int = -1) -> None:
# training will crash without these contiguous calls and the CUDNN limitation
# I believe this is related to https://github.com/pytorch/pytorch/issues/133974
# unresolved at the time of writing
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
if attn_mask is not None:
attn_mask = attn_mask.contiguous()
# out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
# debug attn map
import math
scale_factor = 1 / math.sqrt(q.size(-1))
L, S = q.size(-2), k.size(-2)
attn_bias = torch.zeros(q.shape[0], q.shape[1], L, S, dtype=q.dtype, device=q.device)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias
attn_weight = q @ k.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
torch.save(attn_weight.clone().cpu(), f'./debug_attn_weight_layer{layer_idx}_unnorm.pt')
# normalize
attn_weight = torch.softmax(attn_weight, dim=-1)
torch.save(attn_weight.clone().cpu(), f'./debug_attn_weight_layer{layer_idx}.pt')
def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
def default(val, d):
return val if val is not None else (d() if isfunction(d) else d)
b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device
q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
return attn_mask
class SelfAttention(nn.Module):
def __init__(self, dim: int, nheads: int):
super().__init__()
self.dim = dim
self.nheads = nheads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.q_norm = nn.RMSNorm(dim // nheads)
self.k_norm = nn.RMSNorm(dim // nheads)
self.split_into_heads = Rearrange('b n (h d j) -> b h n d j',
h=nheads,
d=dim // nheads,
j=3)
def pre_attention(
self, x: torch.Tensor,
rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# x: batch_size * n_tokens * n_channels
qkv = self.qkv(x)
q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1)
q = q.squeeze(-1)
k = k.squeeze(-1)
v = v.squeeze(-1)
q = self.q_norm(q)
k = self.k_norm(k)
if rot is not None:
q = apply_rope(q, rot)
k = apply_rope(k, rot)
return q, k, v
def forward(
self,
x: torch.Tensor, # batch_size * n_tokens * n_channels
q_mask: Optional[torch.Tensor] = None,
k_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
q, v, k = self.pre_attention(x)
if q_mask is not None or k_mask is not None:
attn_mask = create_mask(q.shape, k.shape, q.device,
q_mask=q_mask, k_mask=k_mask)
else:
attn_mask = None
out = attention(q, k, v, attn_mask)
return out
class CrossAttention(nn.Module):
def __init__(self, dim: int, nheads: int):
"""
Args:
dim (int): Input dimension.
nheads (int): Number of attention heads.
Attributes:
q_proj (Linear): Linear transformation for the query.
kv_proj (Linear): Linear transformation for the key and value.
q_norm (RMSNorm): Layer normalization for the query.
k_norm (RMSNorm): Layer normalization for the key.
split_into_heads (Rearrange): Rearrange layer to split the input into heads.
"""
super().__init__()
self.dim = dim
self.nheads = nheads
self.q_proj = nn.Linear(dim, dim, bias=True)
self.kv_proj = nn.Linear(dim, dim * 2, bias=True)
self.q_norm = nn.RMSNorm(dim // nheads)
self.k_norm = nn.RMSNorm(dim // nheads)
self.split_q_into_heads = Rearrange('b n (h d) -> b h n d',
h=nheads,
d=dim // nheads)
self.split_kv_into_heads = Rearrange('b n (h d j) -> b h n d j',
h=nheads,
d=dim // nheads,
j=2)
def pre_attention(
self, x: torch.Tensor, c: torch.Tensor,
rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# x: batch_size * n_tokens * n_channels
# c: batch_size * n_cond_tokens * n_channels
q = self.q_proj(x)
kv = self.kv_proj(c)
q = self.split_q_into_heads(q)
k, v = self.split_kv_into_heads(kv).chunk(2, dim=-1)
k = k.squeeze(-1)
v = v.squeeze(-1)
q = self.q_norm(q)
k = self.k_norm(k)
if rot is not None:
q = apply_rope(q, rot)
return q, k, v
def forward(
self,
x: torch.Tensor, # batch_size * n_tokens * n_channels
c: torch.Tensor, # batch_size * n_cond_tokens * n_channels
context_mask: Optional[torch.Tensor] = None,
rot: Optional[torch.Tensor] = None
) -> torch.Tensor:
q, k, v = self.pre_attention(x, c, rot)
if context_mask is not None:
attn_mask = create_mask(q.shape, k.shape, q.device, k_mask=context_mask)
else:
attn_mask = None
out = attention(q, k, v, attn_mask)
return out
class MMCrossAttentionBlock(nn.Module):
def __init__(self,
dim: int,
nhead: int,
mlp_ratio: float = 4.0,
# pre_only: bool = False,
kernel_size: int = 7,
padding: int = 3,
residual: bool = True):
super().__init__()
self.norm1 = nn.LayerNorm(dim, elementwise_affine=True)
self.attn = CrossAttention(dim, nhead)
if kernel_size == 1:
self.linear1 = nn.Linear(dim, dim)
else:
self.linear1 = ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=True)
if kernel_size == 1:
self.ffn = MLP(dim, int(dim * mlp_ratio))
else:
self.ffn = ConvMLP(dim,
int(dim * mlp_ratio),
kernel_size=kernel_size,
padding=padding)
self.residual = residual
def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
# x: BS * N * D
# cond: BS * D
# if self.pre_only:
# (shift_msa, scale_msa) = modulation.chunk(2, dim=-1)
# gate_msa = shift_mlp = scale_mlp = gate_mlp = None
# else:
# (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
# gate_mlp) = modulation.chunk(6, dim=-1)
# x = self.norm1(x)
q, k, v = self.attn.pre_attention(x, c, rot)
return (q, k, v)
def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor):
# if self.pre_only:
# return x
# (gate_msa, shift_mlp, scale_mlp, gate_mlp) = c
if self.residual:
x = x + self.norm1(self.linear1(attn_out)) # * gate_msa
# https://github.com/haidog-yaqub/EzAudio/blob/2eb0bd90013584c6e28a6c14ec28b935f1e78de5/src/models/blocks.py#L158
# https://github.com/huggingface/diffusers/blob/07dd6f8c0e267662f62c39cd8334c2b5d157ab39/src/diffusers/models/transformers/transformer_flux.py#L170
# https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/attention.py#L274
else:
x = self.norm1(self.linear1(attn_out))
r = self.norm2(x)
x = x + self.ffn(r)
return x
def forward(self, x: torch.Tensor, cond: torch.Tensor,
rot: Optional[torch.Tensor],
x_mask: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
# x: BS * N * D
# cond: BS * D
q, k, v = self.pre_attention(x, cond, rot)
if x_mask is not None or context_mask is not None:
attn_mask = create_mask(q.shape, k.shape, q.device, q_mask=x_mask, k_mask=context_mask)
else:
attn_mask = None
attn_out = attention(q, k, v, attn_mask=attn_mask)
x = self.post_attention(x, attn_out)
return x
class MMDitSingleBlock(nn.Module):
def __init__(self,
dim: int,
nhead: int,
mlp_ratio: float = 4.0,
pre_only: bool = False,
kernel_size: int = 7,
padding: int = 3):
super().__init__()
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
self.attn = SelfAttention(dim, nhead)
self.pre_only = pre_only
if pre_only:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
else:
if kernel_size == 1:
self.linear1 = nn.Linear(dim, dim)
else:
self.linear1 = ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
if kernel_size == 1:
self.ffn = MLP(dim, int(dim * mlp_ratio))
else:
self.ffn = ConvMLP(dim,
int(dim * mlp_ratio),
kernel_size=kernel_size,
padding=padding)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
# x: BS * N * D
# cond: BS * D
modulation = self.adaLN_modulation(c)
if self.pre_only:
(shift_msa, scale_msa) = modulation.chunk(2, dim=-1)
gate_msa = shift_mlp = scale_mlp = gate_mlp = None
else:
(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
gate_mlp) = modulation.chunk(6, dim=-1)
x = modulate(self.norm1(x), shift_msa, scale_msa)
q, k, v = self.attn.pre_attention(x, rot)
return (q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp)
def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: tuple[torch.Tensor]):
if self.pre_only:
return x
(gate_msa, shift_mlp, scale_mlp, gate_mlp) = c
x = x + self.linear1(attn_out) * gate_msa
r = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = x + self.ffn(r) * gate_mlp
return x
def forward(self, x: torch.Tensor, cond: torch.Tensor,
rot: Optional[torch.Tensor]) -> torch.Tensor:
# x: BS * N * D
# cond: BS * D
x_qkv, x_conditions = self.pre_attention(x, cond, rot)
attn_out = attention(*x_qkv)
x = self.post_attention(x, attn_out, x_conditions)
return x
class JointBlock(nn.Module):
def __init__(self, dim: int, nhead: int, mlp_ratio: float = 4.0, pre_only: bool = False):
super().__init__()
self.pre_only = pre_only
self.latent_block = MMDitSingleBlock(dim,
nhead,
mlp_ratio,
pre_only=False,
kernel_size=3,
padding=1)
self.clip_block = MMDitSingleBlock(dim,
nhead,
mlp_ratio,
pre_only=pre_only,
kernel_size=3,
padding=1)
self.text_block = MMDitSingleBlock(dim, nhead, mlp_ratio, pre_only=pre_only, kernel_size=1)
def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor,
global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: torch.Tensor,
clip_rot: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# latent: BS * N1 * D
# clip_f: BS * N2 * D
# c: BS * (1/N) * D
x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
c_qkv, c_mod = self.clip_block.pre_attention(clip_f, global_c, clip_rot)
t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)
latent_len = latent.shape[1]
clip_len = clip_f.shape[1]
text_len = text_f.shape[1]
joint_qkv = [torch.cat([x_qkv[i], c_qkv[i], t_qkv[i]], dim=2) for i in range(3)]
attn_out = attention(*joint_qkv)
x_attn_out = attn_out[:, :latent_len]
c_attn_out = attn_out[:, latent_len:latent_len + clip_len]
t_attn_out = attn_out[:, latent_len + clip_len:]
latent = self.latent_block.post_attention(latent, x_attn_out, x_mod)
if not self.pre_only:
clip_f = self.clip_block.post_attention(clip_f, c_attn_out, c_mod)
text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod)
return latent, clip_f, text_f
def forward_debug(self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor,
global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: torch.Tensor,
clip_rot: torch.Tensor,
layer_idx: int = -1,
) -> None:
# latent: BS * N1 * D
# clip_f: BS * N2 * D
# c: BS * (1/N) * D
x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
c_qkv, c_mod = self.clip_block.pre_attention(clip_f, global_c, clip_rot)
t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)
latent_len = latent.shape[1]
clip_len = clip_f.shape[1]
text_len = text_f.shape[1]
joint_qkv = [torch.cat([x_qkv[i], c_qkv[i], t_qkv[i]], dim=2) for i in range(3)]
attn_out = attention_debug(*joint_qkv, layer_idx=layer_idx)
return None
class FinalBlock(nn.Module):
def __init__(self, dim, out_dim):
super().__init__()
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
self.conv = ChannelLastConv1d(dim, out_dim, kernel_size=7, padding=3)
def forward(self, latent, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
latent = modulate(self.norm(latent), shift, scale)
latent = self.conv(latent)
return latent