chore: vendor selva_core from jnwnlee/selva@d7d40a9

Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced
librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy
version incompatibility in some ComfyUI environments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-04 15:18:09 +02:00
parent 762b19fd3a
commit 6bc3fd6443
106 changed files with 11323 additions and 0 deletions
View File
+49
View File
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
# https://github.com/facebookresearch/DiT
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, dim, frequency_embedding_size, max_period):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, dim),
nn.SiLU(),
nn.Linear(dim, dim),
)
self.dim = dim
self.max_period = max_period
assert dim % 2 == 0, 'dim must be even.'
with torch.autocast('cuda', enabled=False):
self.freqs = nn.Buffer(
1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
frequency_embedding_size)),
persistent=False)
freq_scale = 10000 / max_period
self.freqs = freq_scale * self.freqs
def timestep_embedding(self, t):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
args = t[:, None].float() * self.freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t).to(t.dtype)
t_emb = self.mlp(t_freq)
return t_emb
+87
View File
@@ -0,0 +1,87 @@
import logging
from typing import Callable, Optional
import torch
from torchdiffeq import odeint
log = logging.getLogger()
# Partially from https://github.com/gle-bellier/flow-matching
class FlowMatching:
    """
    Flow-matching helper: conditional interpolation path, training loss and
    ODE integration between the prior (t = 0) and the data (t = 1).
    """

    def __init__(self, min_sigma: float = 0.0, inference_mode='euler', num_steps: int = 25,
                 target: str = 'v'):
        """
        :param min_sigma: minimum noise scale used in the conditional flow.
        :param inference_mode: 'euler' or 'adaptive'.
        :param num_steps: number of steps in the euler inference mode.
        :param target: what the network output is compared against in ``loss``:
                       'v' (velocity) or 'x1' (data).
        """
        super().__init__()
        self.min_sigma = min_sigma
        self.inference_mode = inference_mode
        self.num_steps = num_steps
        self.target = target

        assert self.inference_mode in ['euler', 'adaptive']
        if self.inference_mode == 'adaptive' and num_steps > 0:
            log.info('The number of steps is ignored in adaptive inference mode ')

    def get_conditional_flow(self, x0: torch.Tensor, x1: torch.Tensor,
                             t: torch.Tensor) -> torch.Tensor:
        """
        Interpolate between ``x0`` and ``x1`` at times ``t``.

        ``t`` is indexed as ``t[:, None, None]`` and expanded to the shape of
        ``x0``, so ``x0`` is expected to be 3-D with ``t`` holding one value per
        batch element.
        """
        # which is psi_t(x), eq 22 in flow matching for generative models
        t = t[:, None, None].expand_as(x0)
        return (1 - (1 - self.min_sigma) * t) * x0 + t * x1

    def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor,
             xt: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Mean squared error per batch element (every dimension except the first
        is averaged out).

        :raises ValueError: if ``target`` is 'x1' and ``xt`` or ``t`` is missing,
                            or if ``target`` is not one of 'v' / 'x1'.
        """
        # return the mean error without reducing the batch dimension
        reduce_dim = list(range(1, len(predicted_v.shape)))
        if self.target == 'v':
            target_v = x1 - (1 - self.min_sigma) * x0
            return (predicted_v - target_v).pow(2).mean(dim=reduce_dim)
        elif self.target == 'x1':
            if xt is None or t is None:
                raise ValueError("xt and t must be provided when target is 'x1'")
            t = t[:, None, None].expand_as(x0)
            predicted_x1 = xt + (1 - t) * predicted_v - self.min_sigma * x0
            return (predicted_x1 - x1).pow(2).mean(dim=reduce_dim)
        else:
            raise ValueError(f"Unknown target: {self.target}. Supported targets are 'v' and 'x1'.")

    def get_x0_xt_c(
        self,
        x1: torch.Tensor,
        t: torch.Tensor,
        Cs: list[torch.Tensor],
        generator: Optional[torch.Generator] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]:
        """
        Draw Gaussian noise ``x0`` shaped like ``x1`` and build the interpolant ``xt``.

        :return: ``(x0, x1, xt, Cs)``; ``x1`` and ``Cs`` are passed through unchanged.
        """
        x0 = torch.empty_like(x1).normal_(generator=generator)
        xt = self.get_conditional_flow(x0, x1, t)
        return x0, x1, xt, Cs

    def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor:
        """Integrate from t = 1 to t = 0, starting at ``x1``."""
        return self.run_t0_to_t1(fn, x1, 1, 0)

    def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
        """Integrate from t = 0 to t = 1, starting at ``x0``."""
        return self.run_t0_to_t1(fn, x0, 0, 1)

    def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
        """
        Integrate ``dx/dt = fn(t, x)`` from ``t0`` to ``t1`` and return the final state.

        :param fn: a function that takes (t, x) and returns the direction x0->x1.
        :param x0: state at ``t0``.
        :return: state at the end of the integration, with the same shape as ``x0``.
        """
        if self.inference_mode == 'adaptive':
            time_points = torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype)
            # odeint returns the solution at every requested time point stacked along
            # dim 0 (the first entry being the initial value); keep only the state at
            # t1 so that both inference modes return a tensor shaped like x0.
            return odeint(fn, x0, time_points)[-1]
        elif self.inference_mode == 'euler':
            x = x0
            steps = torch.linspace(t0, t1 - self.min_sigma, self.num_steps + 1)
            for ti, t in enumerate(steps[:-1]):
                flow = fn(t, x)
                next_t = steps[ti + 1]
                dt = next_t - t
                x = x + dt * flow
            return x
        # Only reachable if `inference_mode` was changed after construction.
        raise ValueError(f'Unknown inference mode: {self.inference_mode}')
+95
View File
@@ -0,0 +1,95 @@
import torch
from torch import nn
from torch.nn import functional as F
class ChannelLastConv1d(nn.Conv1d):
    """``nn.Conv1d`` applied to input whose last two axes are swapped (channels last)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Conv1d convolves over the last axis with channels on axis 1, so swap
        # axes 1 and 2 on the way in and swap them back on the way out.
        channels_first = x.permute(0, 2, 1)
        convolved = super().forward(channels_first)
        return convolved.permute(0, 2, 1)
# https://github.com/Stability-AI/sd3-ref
class MLP(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))`` with bias-free linear layers."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input (and output) dimension.
            hidden_dim (int): Nominal hidden dimension. The width actually used is
                ``int(2 * hidden_dim / 3)`` rounded up to a multiple of ``multiple_of``.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.

        Attributes:
            w1 (nn.Linear): dim -> hidden projection that goes through SiLU (the gate).
            w2 (nn.Linear): hidden -> dim output projection.
            w3 (nn.Linear): dim -> hidden projection multiplied with the gate.
        """
        super().__init__()
        two_thirds = int(2 * hidden_dim / 3)
        num_chunks = (two_thirds + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_chunks

        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class ConvMLP(nn.Module):
    """Gated feed-forward block built from channel-last 1-D convolutions."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input (and output) channel dimension.
            hidden_dim (int): Nominal hidden dimension. The width actually used is
                ``int(2 * hidden_dim / 3)`` rounded up to a multiple of ``multiple_of``.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            kernel_size (int): Kernel size shared by the three convolutions.
            padding (int): Padding shared by the three convolutions.

        Attributes:
            w1 (ChannelLastConv1d): dim -> hidden convolution that goes through SiLU (the gate).
            w2 (ChannelLastConv1d): hidden -> dim output convolution.
            w3 (ChannelLastConv1d): dim -> hidden convolution multiplied with the gate.
        """
        super().__init__()
        two_thirds = int(2 * hidden_dim / 3)
        num_chunks = (two_thirds + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_chunks

        # All three convolutions are bias-free and share kernel size and padding.
        conv_opts = dict(bias=False, kernel_size=kernel_size, padding=padding)
        self.w1 = ChannelLastConv1d(dim, inner_dim, **conv_opts)
        self.w2 = ChannelLastConv1d(inner_dim, dim, **conv_opts)
        self.w3 = ChannelLastConv1d(dim, inner_dim, **conv_opts)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
+475
View File
@@ -0,0 +1,475 @@
import logging
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from selva_core.ext.rotary_embeddings import compute_rope_rotations
from selva_core.model.embeddings import TimestepEmbedder
from selva_core.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from selva_core.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock)
log = logging.getLogger()
@dataclass
class PreprocessedConditions:
    """Conditioning tensors built by ``MMAudio.preprocess_conditions`` and read by ``MMAudio.predict_flow``."""
    # CLIP sequence after `clip_input_proj`
    clip_f: torch.Tensor
    # sync sequence after `sync_input_proj`, interpolated to the latent sequence length
    sync_f: torch.Tensor
    # text sequence after `text_input_proj`
    text_f: torch.Tensor
    # `clip_cond_proj` applied to `clip_f` averaged over dim 1
    clip_f_c: torch.Tensor
    # `text_cond_proj` applied to `text_f` averaged over dim 1
    text_f_c: torch.Tensor
# Partially from https://github.com/facebookresearch/DiT
class MMAudio(nn.Module):
def __init__(self,
             *,
             latent_dim: int,
             clip_dim: int,
             sync_dim: int,
             text_dim: int,
             hidden_dim: int,
             depth: int,
             fused_depth: int,
             num_heads: int,
             mlp_ratio: float = 4.0,
             latent_seq_len: int,
             clip_seq_len: int,
             sync_seq_len: int,
             text_seq_len: int = 77,
             latent_mean: Optional[torch.Tensor] = None,
             latent_std: Optional[torch.Tensor] = None,
             empty_string_feat: Optional[torch.Tensor] = None,
             v2: bool = False) -> None:
    """
    Build the network. All arguments are keyword-only.

    Args:
        latent_dim: channel size of the audio latent (input and output of the network).
        clip_dim / sync_dim / text_dim: feature sizes of the three conditioning inputs.
        hidden_dim: width shared by all projections and transformer blocks.
        depth: total number of transformer blocks; ``depth - fused_depth`` of them are
            ``JointBlock``s and ``fused_depth`` are ``MMDitSingleBlock``s.
        num_heads: number of attention heads passed to every block.
        mlp_ratio: forwarded to every block.
        latent_seq_len / clip_seq_len / sync_seq_len / text_seq_len: expected sequence
            lengths; they are asserted in ``preprocess_conditions`` / ``predict_flow``.
        latent_mean / latent_std: normalization statistics with ``latent_dim`` elements.
            Either both or neither must be given; when omitted they are filled with NaN
            and are expected to be loaded from a checkpoint later.
        empty_string_feat: text features used for the empty prompt; zeros of shape
            ``(text_seq_len, text_dim)`` when omitted.
        v2: selects the v2 variant of the input projections and timestep embedder.
    """
    super().__init__()

    self.v2 = v2
    self.latent_dim = latent_dim
    self._latent_seq_len = latent_seq_len
    self._clip_seq_len = clip_seq_len
    self._sync_seq_len = sync_seq_len
    self._text_seq_len = text_seq_len
    self.hidden_dim = hidden_dim
    self.num_heads = num_heads

    # Input projections. The v2 variant uses SiLU everywhere and adds an activation
    # after the first CLIP/text linear layer; the original variant uses SELU on the
    # audio/sync paths and no activation on the CLIP/text paths.
    if v2:
        self.audio_input_proj = nn.Sequential(
            ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3),
            nn.SiLU(),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
        )

        self.clip_input_proj = nn.Sequential(
            nn.Linear(clip_dim, hidden_dim),
            nn.SiLU(),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
        )

        self.sync_input_proj = nn.Sequential(
            ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3),
            nn.SiLU(),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
        )

        self.text_input_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.SiLU(),
            MLP(hidden_dim, hidden_dim * 4),
        )
    else:
        self.audio_input_proj = nn.Sequential(
            ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3),
            nn.SELU(),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
        )

        self.clip_input_proj = nn.Sequential(
            nn.Linear(clip_dim, hidden_dim),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
        )

        self.sync_input_proj = nn.Sequential(
            ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3),
            nn.SELU(),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
        )

        self.text_input_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            MLP(hidden_dim, hidden_dim * 4),
        )

    # Projections used to build the global conditioning vector.
    self.clip_cond_proj = nn.Linear(hidden_dim, hidden_dim)
    self.text_cond_proj = nn.Linear(hidden_dim, hidden_dim)
    self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4)
    # each synchformer output segment has 8 feature frames
    self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, sync_dim)))

    self.final_layer = FinalBlock(hidden_dim, latent_dim)

    if v2:
        self.t_embed = TimestepEmbedder(hidden_dim,
                                        frequency_embedding_size=hidden_dim,
                                        max_period=1)
    else:
        self.t_embed = TimestepEmbedder(hidden_dim,
                                        frequency_embedding_size=256,
                                        max_period=10000)
    # The last joint block is constructed with pre_only=True.
    self.joint_blocks = nn.ModuleList([
        JointBlock(hidden_dim,
                   num_heads,
                   mlp_ratio=mlp_ratio,
                   pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth)
    ])
    self.fused_blocks = nn.ModuleList([
        MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
        for i in range(fused_depth)
    ])

    if latent_mean is None:
        # these values are not meant to be used
        # if you don't provide mean/std here, we should load them later from a checkpoint
        assert latent_std is None
        latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
        latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
    else:
        assert latent_std is not None
        assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}'
    if empty_string_feat is None:
        empty_string_feat = torch.zeros((text_seq_len, text_dim))
    # Frozen parameters: stored in the state dict but excluded from gradient updates.
    self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False)
    self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False)

    self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)
    # Learnable placeholders for missing CLIP / sync features.
    self.empty_clip_feat = nn.Parameter(torch.zeros(1, clip_dim), requires_grad=True)
    self.empty_sync_feat = nn.Parameter(torch.zeros(1, sync_dim), requires_grad=True)

    self.initialize_weights()
    self.initialize_rotations()
def initialize_rotations(self):
    """
    (Re)build the rotary-embedding buffers ``latent_rot`` and ``clip_rot``.

    Both are computed by ``compute_rope_rotations`` for the current sequence
    lengths on ``self.device`` and stored as non-persistent buffers. The CLIP
    frequencies are scaled by ``latent_seq_len / clip_seq_len``.
    """
    base_freq = 1.0
    head_dim = self.hidden_dim // self.num_heads
    clip_freq_scaling = base_freq * self._latent_seq_len / self._clip_seq_len

    latent_rot = compute_rope_rotations(self._latent_seq_len,
                                        head_dim,
                                        10000,
                                        freq_scaling=base_freq,
                                        device=self.device)
    clip_rot = compute_rope_rotations(self._clip_seq_len,
                                      head_dim,
                                      10000,
                                      freq_scaling=clip_freq_scaling,
                                      device=self.device)

    self.latent_rot = nn.Buffer(latent_rot, persistent=False)
    self.clip_rot = nn.Buffer(clip_rot, persistent=False)
def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
    """Store new expected sequence lengths and rebuild the rotary-embedding buffers."""
    self._latent_seq_len, self._clip_seq_len, self._sync_seq_len = (
        latent_seq_len,
        clip_seq_len,
        sync_seq_len,
    )
    self.initialize_rotations()
def initialize_weights(self):
    """
    Initialize parameters.

    Every ``nn.Linear`` gets Xavier-uniform weights and zero bias; the timestep MLP
    weights are then redrawn from N(0, 0.02). The last adaLN modulation layer of each
    block, the final layer's modulation and conv, and the sync positional embedding and
    empty clip/sync features are set to zero.
    """

    def _basic_init(module):
        if not isinstance(module, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    self.apply(_basic_init)

    # Initialize timestep embedding MLP:
    for layer_idx in (0, 2):
        nn.init.normal_(self.t_embed.mlp[layer_idx].weight, std=0.02)

    # Zero-out adaLN modulation layers in DiT blocks:
    for block in self.joint_blocks:
        for stream in (block.latent_block, block.clip_block, block.text_block):
            nn.init.constant_(stream.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(stream.adaLN_modulation[-1].bias, 0)
    for block in self.fused_blocks:
        nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    # Zero-out output layers:
    final = self.final_layer
    for tensor in (final.adaLN_modulation[-1].weight, final.adaLN_modulation[-1].bias,
                   final.conv.weight, final.conv.bias):
        nn.init.constant_(tensor, 0)

    # empty string feat shall be initialized by a CLIP encoder
    for tensor in (self.sync_pos_emb, self.empty_clip_feat, self.empty_sync_feat):
        nn.init.constant_(tensor, 0)
def normalize(self, x: torch.Tensor) -> torch.Tensor:
# return (x - self.latent_mean) / self.latent_std
return x.sub_(self.latent_mean).div_(self.latent_std)
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
# return x * self.latent_std + self.latent_mean
return x.mul_(self.latent_std).add_(self.latent_mean)
def preprocess_conditions(self, clip_f: torch.Tensor, sync_f: torch.Tensor,
                          text_f: torch.Tensor) -> PreprocessedConditions:
    """
    Project the conditioning inputs once so they can be reused at every sampling step.

    Nothing computed here depends on the latent or on the time step. Dim 1 of each
    input must equal the configured clip / sync / text sequence length.
    """
    assert clip_f.shape[1] == self._clip_seq_len, f'{clip_f.shape=} {self._clip_seq_len=}'
    assert sync_f.shape[1] == self._sync_seq_len, f'{sync_f.shape=} {self._sync_seq_len=}'
    assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}'

    batch = clip_f.shape[0]

    # Add the per-segment positional embedding: split the sync sequence into segments
    # of 8 frames, add `sync_pos_emb` (broadcast over batch and segments), merge back.
    segments = self._sync_seq_len // 8
    sync_f = (sync_f.view(batch, segments, 8, -1) + self.sync_pos_emb).flatten(1, 2)

    # project every modality to the hidden width
    clip_f = self.clip_input_proj(clip_f)
    sync_f = self.sync_input_proj(sync_f)
    text_f = self.text_input_proj(text_f)

    # upsample the sync features along the sequence axis to match the audio latent length
    sync_f = F.interpolate(sync_f.transpose(1, 2),
                           size=self._latent_seq_len,
                           mode='nearest-exact').transpose(1, 2)

    # global conditioning vectors: sequence-averaged features through a linear layer
    pooled_clip = self.clip_cond_proj(clip_f.mean(dim=1))
    pooled_text = self.text_cond_proj(text_f.mean(dim=1))

    return PreprocessedConditions(clip_f=clip_f,
                                  sync_f=sync_f,
                                  text_f=text_f,
                                  clip_f_c=pooled_clip,
                                  text_f_c=pooled_text)
def predict_flow(self, latent: torch.Tensor, t: torch.Tensor,
                 conditions: PreprocessedConditions) -> torch.Tensor:
    """
    Run the part of the network that depends on the latent and the time step.

    ``conditions`` comes from ``preprocess_conditions``; dim 1 of ``latent`` must
    equal the configured latent sequence length.
    """
    assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}'

    clip_f, text_f = conditions.clip_f, conditions.text_f

    latent = self.audio_input_proj(latent)

    # global conditioning: pooled clip + text, plus the timestep embedding;
    # both terms get a singleton axis at dim 1 so they broadcast over the sequence
    pooled = self.global_cond_mlp(conditions.clip_f_c + conditions.text_f_c)
    global_c = self.t_embed(t).unsqueeze(1) + pooled.unsqueeze(1)
    # per-position conditioning additionally carries the upsampled sync features
    extended_c = global_c + conditions.sync_f

    for joint_block in self.joint_blocks:
        latent, clip_f, text_f = joint_block(latent, clip_f, text_f, global_c, extended_c,
                                             self.latent_rot, self.clip_rot)

    for fused_block in self.fused_blocks:
        latent = fused_block(latent, extended_c, self.latent_rot)

    return self.final_layer(latent, global_c)
def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, sync_f: torch.Tensor,
            text_f: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    Preprocess the conditions and predict the flow in one call.

    latent: (B, N, C)
    vf: (B, T, C_V)
    t: (B,)
    """
    return self.predict_flow(latent, t, self.preprocess_conditions(clip_f, sync_f, text_f))
def get_empty_string_sequence(self, bs: int) -> torch.Tensor:
return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)
def get_empty_clip_sequence(self, bs: int) -> torch.Tensor:
return self.empty_clip_feat.unsqueeze(0).expand(bs, self._clip_seq_len, -1)
def get_empty_sync_sequence(self, bs: int) -> torch.Tensor:
return self.empty_sync_feat.unsqueeze(0).expand(bs, self._sync_seq_len, -1)
def get_empty_conditions(
        self,
        bs: int,
        *,
        negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions:
    """
    Build the conditions used for the unconditional branch.

    The empty clip and sync features are preprocessed with batch size 1 and then
    expanded to ``bs``. The text side uses ``negative_text_features`` as given when
    provided (its results are not expanded); otherwise the empty-string features are
    used and expanded like the others.
    """
    use_default_text = negative_text_features is None
    empty_text = self.get_empty_string_sequence(1) if use_default_text else negative_text_features

    conditions = self.preprocess_conditions(self.get_empty_clip_sequence(1),
                                            self.get_empty_sync_sequence(1),
                                            empty_text)

    conditions.clip_f = conditions.clip_f.expand(bs, -1, -1)
    conditions.sync_f = conditions.sync_f.expand(bs, -1, -1)
    conditions.clip_f_c = conditions.clip_f_c.expand(bs, -1)
    if use_default_text:
        conditions.text_f = conditions.text_f.expand(bs, -1, -1)
        conditions.text_f_c = conditions.text_f_c.expand(bs, -1)

    return conditions
def ode_wrapper(self, t: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions,
                empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor:
    """
    Velocity function for the ODE solver, with optional classifier-free guidance.

    ``t`` is broadcast to one value per batch element. When ``cfg_strength`` is below
    1.0 only the conditional prediction is returned; otherwise the result is
    ``cfg_strength * conditional + (1 - cfg_strength) * unconditional``.
    """
    t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype)

    conditional = self.predict_flow(latent, t, conditions)
    if cfg_strength < 1.0:
        return conditional

    unconditional = self.predict_flow(latent, t, empty_conditions)
    return cfg_strength * conditional + (1 - cfg_strength) * unconditional
def load_weights(self, src_dict) -> None:
if 't_embed.freqs' in src_dict:
del src_dict['t_embed.freqs']
if 'latent_rot' in src_dict:
del src_dict['latent_rot']
if 'clip_rot' in src_dict:
del src_dict['clip_rot']
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
    """Device of the ``latent_mean`` parameter, used as the device of the model."""
    return self.latent_mean.device

@property
def latent_seq_len(self) -> int:
    """Currently configured latent sequence length."""
    return self._latent_seq_len

@property
def clip_seq_len(self) -> int:
    """Currently configured CLIP sequence length."""
    return self._clip_seq_len

@property
def sync_seq_len(self) -> int:
    """Currently configured sync sequence length."""
    return self._sync_seq_len
def small_16k(**kwargs) -> MMAudio:
    """Build the small 16 kHz MMAudio configuration; ``kwargs`` are forwarded to ``MMAudio``."""
    num_heads = 7
    config = dict(
        latent_dim=20,
        clip_dim=1024,
        sync_dim=768,
        text_dim=1024,
        hidden_dim=64 * num_heads,
        depth=12,
        fused_depth=8,
        num_heads=num_heads,
        latent_seq_len=250,
        clip_seq_len=64,
        sync_seq_len=192,
    )
    return MMAudio(**config, **kwargs)
def small_44k(**kwargs) -> MMAudio:
    """Build the small 44.1 kHz MMAudio configuration; ``kwargs`` are forwarded to ``MMAudio``."""
    num_heads = 7
    config = dict(
        latent_dim=40,
        clip_dim=1024,
        sync_dim=768,
        text_dim=1024,
        hidden_dim=64 * num_heads,
        depth=12,
        fused_depth=8,
        num_heads=num_heads,
        latent_seq_len=345,
        clip_seq_len=64,
        sync_seq_len=192,
    )
    return MMAudio(**config, **kwargs)
def medium_44k(**kwargs) -> MMAudio:
    """Build the medium 44.1 kHz MMAudio configuration; ``kwargs`` are forwarded to ``MMAudio``."""
    num_heads = 14
    config = dict(
        latent_dim=40,
        clip_dim=1024,
        sync_dim=768,
        text_dim=1024,
        hidden_dim=64 * num_heads,
        depth=12,
        fused_depth=8,
        num_heads=num_heads,
        latent_seq_len=345,
        clip_seq_len=64,
        sync_seq_len=192,
    )
    return MMAudio(**config, **kwargs)
def large_44k(**kwargs) -> MMAudio:
    """Build the large 44.1 kHz MMAudio configuration; ``kwargs`` are forwarded to ``MMAudio``."""
    num_heads = 14
    config = dict(
        latent_dim=40,
        clip_dim=1024,
        sync_dim=768,
        text_dim=1024,
        hidden_dim=64 * num_heads,
        depth=21,
        fused_depth=14,
        num_heads=num_heads,
        latent_seq_len=345,
        clip_seq_len=64,
        sync_seq_len=192,
    )
    return MMAudio(**config, **kwargs)
def large_44k_v2(**kwargs) -> MMAudio:
    """Build the large 44.1 kHz MMAudio configuration with ``v2=True``; ``kwargs`` are forwarded."""
    num_heads = 14
    config = dict(
        latent_dim=40,
        clip_dim=1024,
        sync_dim=768,
        text_dim=1024,
        hidden_dim=64 * num_heads,
        depth=21,
        fused_depth=14,
        num_heads=num_heads,
        latent_seq_len=345,
        clip_seq_len=64,
        sync_seq_len=192,
        v2=True,
    )
    return MMAudio(**config, **kwargs)
def get_my_mmaudio(name: str, **kwargs) -> MMAudio:
    """
    Build the MMAudio variant called ``name``, forwarding ``kwargs`` to its factory.

    Raises:
        ValueError: if ``name`` is not one of the known configurations.
    """
    if name == 'small_16k':
        builder = small_16k
    elif name == 'small_44k':
        builder = small_44k
    elif name == 'medium_44k':
        builder = medium_44k
    elif name == 'large_44k':
        builder = large_44k
    elif name == 'large_44k_v2':
        builder = large_44k_v2
    else:
        raise ValueError(f'Unknown model name: {name}')
    return builder(**kwargs)
if __name__ == '__main__':
    # Smoke test: build the smallest model and report its size.
    network = get_my_mmaudio('small_16k')

    # print the number of parameters in terms of millions
    total_elements = 0
    for param in network.parameters():
        total_elements += param.numel()
    num_params = total_elements / 1e6
    print(f'Number of parameters: {num_params:.2f}M')
+188
View File
@@ -0,0 +1,188 @@
from typing import Optional, Union, List, Tuple, Any, Mapping
from dataclasses import dataclass
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from selva_core.model.text_synchformer import TextSynchformer
from selva_core.utils.transforms import generate_multiple_segments
@dataclass
class PreprocessedConditions:
    """Container for preprocessed sync and text conditioning tensors.

    NOTE(review): nothing in the visible part of this file constructs or reads this
    class, so the field meanings below are inferred from the names — confirm with the
    code that uses it.
    """
    # presumably the sync feature sequence — TODO confirm
    sync_f: torch.Tensor
    # presumably a pooled/global version of `sync_f` — TODO confirm
    sync_f_c: torch.Tensor
    # presumably the text feature sequence — TODO confirm
    text_f: torch.Tensor
    # presumably a pooled/global version of `text_f` — TODO confirm
    text_f_c: torch.Tensor
    # presumably the validity mask that goes with `text_f` — TODO confirm
    text_mask: torch.Tensor
class TextSynch(TextSynchformer):
def __init__(self,
             *,
             text_dim: int,
             video_seq_len: int = 192,
             max_text_seq_len: int = 512,
             empty_string_feat: Optional[torch.Tensor] = None,
             num_sup_text_tokens: int = 5,
             sync_batch_size_multiplier: Union[int, float] = -1,
             xattn_depth: int = 1,
             ) -> None:
    """
    Args:
        text_dim: feature size of the incoming text features.
        video_seq_len: stored as the configured video sequence length.
        max_text_seq_len: forwarded to the parent class.
        empty_string_feat: features for the empty prompt; zeros of shape
            ``(1, text_dim)`` when omitted. Stored as a frozen parameter.
        num_sup_text_tokens: number of learnable supplementary text tokens; when it
            is not positive, no ``sup_text_feat`` parameter is created.
        sync_batch_size_multiplier: controls the chunk size used by
            ``encode_video_with_sync``; values <= 0 mean "use the batch size".
        xattn_depth: forwarded to the parent class.
    """
    super().__init__(
        text_dim=text_dim,
        max_text_seq_len=max_text_seq_len,
        xattn_depth=xattn_depth,
    )

    self._video_seq_len = video_seq_len
    self.num_sup_text_tokens = num_sup_text_tokens
    self.sync_batch_size_multiplier = sync_batch_size_multiplier

    if num_sup_text_tokens > 0:
        sup_tokens = torch.zeros(num_sup_text_tokens, self.text_dim)
        self.sup_text_feat = nn.Parameter(sup_tokens, requires_grad=True)

    empty_feat = torch.zeros((1, text_dim)) if empty_string_feat is None else empty_string_feat
    self.empty_string_feat = nn.Parameter(empty_feat, requires_grad=False)

    self.initialize_weights()
def update_seq_lengths(self, video_seq_len: int) -> None:
    """Replace the stored video sequence length (exposed through ``video_seq_len``)."""
    self._video_seq_len = video_seq_len
def get_empty_string_sequence(self, bs: int) -> torch.Tensor:
    """Return ``empty_string_feat`` with a new leading axis expanded to ``bs`` (a view, no copy)."""
    batched = self.empty_string_feat[None]
    return batched.expand(bs, -1, -1)
def get_sup_text_sequence(self, bs: int) -> torch.Tensor:
    """
    Return the supplementary text tokens expanded to a batch of ``bs``.

    Raises:
        ValueError: if supplementary tokens are disabled (``num_sup_text_tokens <= 0``).
    """
    if self.num_sup_text_tokens > 0:
        return self.sup_text_feat.expand(bs, -1, -1)
    raise ValueError(f'supplementary text tokens not enabled as {self.num_sup_text_tokens=}')
def prepend_sup_text_tokens(self, text_f: torch.Tensor, text_mask: torch.Tensor) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Prepend the supplementary text tokens to ``text_f`` along dim 1.

    ``text_mask`` is extended with ones for the new positions, using the mask's own
    device and dtype. When supplementary tokens are disabled both inputs are returned
    unchanged.
    """
    if self.num_sup_text_tokens <= 0:
        return text_f, text_mask

    batch = text_f.shape[0]
    sup_tokens = self.get_sup_text_sequence(batch)  # (B, S, D)
    sup_mask = torch.ones(batch, sup_tokens.shape[1],
                          device=text_mask.device, dtype=text_mask.dtype)  # (B, S)

    extended_text = torch.cat([sup_tokens, text_f], dim=1)
    extended_mask = torch.cat([sup_mask, text_mask], dim=1)
    return extended_text, extended_mask
def encode_video_with_sync(self, x: torch.Tensor, text_f: torch.Tensor,
                           text_mask: torch.Tensor) -> torch.Tensor:
    """
    Encode a batch of videos segment by segment, conditioned on text.

    The video is cut into overlapping segments (16 frames, step 8) by
    ``generate_multiple_segments``; the segments of all samples are flattened and
    passed to ``forward_vfeat`` in chunks, each segment paired with the text
    features / mask of the sample it came from.

    Args:
        x: (B, T, C, H, W) video with C == 3 and H == W == 224.
        text_f: text features, indexed by sample along dim 0.
        text_mask: text mask, indexed by sample along dim 0.

    Returns:
        Tensor of shape (B, S * t, D), where S is the number of segments and
        (t, D) are the trailing dims returned by ``forward_vfeat``.
    """
    # x: (B, T, C, H, W) H/W: 224
    b, _, c, h, w = x.shape
    assert c == 3 and h == 224 and w == 224

    # partition the video
    segment_size = 16
    step_size = 8
    x = generate_multiple_segments(x, segment_size, step_size)  # (B, S, T, C, H, W)
    num_segments = x.shape[1]
    total_segments = b * num_segments

    if self.sync_batch_size_multiplier <= 0:
        batch_size = b
    else:
        # A small fractional multiplier would truncate to 0 and make range() below
        # raise "arg 3 must not be zero"; always process at least one segment.
        batch_size = max(1, int(b * self.sync_batch_size_multiplier))

    x = einops.rearrange(x, 'b s t c h w -> (b s) 1 t c h w')

    outputs = []
    for start in range(0, total_segments, batch_size):
        stop = min(start + batch_size, total_segments)
        # Flat segment k belongs to sample k // num_segments. Gathering with these
        # indices gives the same rows as slicing the samples and repeat-interleaving
        # them, without filling a per-sample counter element by element.
        sample_idx = torch.div(torch.arange(start, stop, device=text_f.device),
                               num_segments,
                               rounding_mode='floor')
        outputs.append(self.forward_vfeat(
            x[start:stop],
            text_f=text_f[sample_idx],
            text_mask=text_mask[sample_idx]
        ))

    x = torch.cat(outputs, dim=0)
    x = einops.rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
    return x
def encode_audio_with_sync(self, x: torch.Tensor, text_f: torch.Tensor,
                           text_mask: torch.Tensor) -> torch.Tensor:
    """Delegate to ``forward_afeat`` with the text features and mask passed by keyword.

    NOTE(review): verify that the inherited ``forward_afeat`` accepts the ``text_f`` /
    ``text_mask`` keywords and that audio feature extraction is actually implemented.
    """
    return self.forward_afeat(
        x, text_f=text_f, text_mask=text_mask
    )
def load_synchformer_state_dict(self, src_dict: dict):
    """Strictly load ``src_dict`` through this class's ``load_state_dict``."""
    self.load_state_dict(src_dict, strict=True)

def load_weights(self, src_dict) -> None:
    """Strictly load ``src_dict`` through this class's ``load_state_dict``.

    Performs the same call as ``load_synchformer_state_dict``.
    """
    self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
    """Device of the ``empty_string_feat`` parameter."""
    return self.empty_string_feat.device

@property
def dtype(self) -> torch.dtype:
    """Dtype of the ``empty_string_feat`` parameter."""
    return self.empty_string_feat.dtype

@property
def video_seq_len(self) -> int:
    """Currently configured video sequence length."""
    return self._video_seq_len

@property
def audio_seq_len(self) -> int:
    """Stored audio sequence length.

    NOTE(review): ``_audio_seq_len`` is not assigned anywhere in the visible part of
    this class — confirm it is set elsewhere before relying on this property.
    """
    return self._audio_seq_len
def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
    """
    Load only the entries of ``sd`` that belong to this model.

    Entries are kept when their key starts with one of: the enabled feature
    extractors (``vfeat_extractor`` / ``afeat_extractor``), ``text_proj``,
    ``synch_text_cross_blocks``, ``sup_text_feat`` or ``empty_string_feat``.
    Everything else is discarded before delegating to ``nn.Module.load_state_dict``.
    """
    wanted = []
    if self.video:
        wanted.append('vfeat_extractor')
    if self.audio:
        wanted.append('afeat_extractor')
    wanted += ['text_proj', 'synch_text_cross_blocks',
               'sup_text_feat', 'empty_string_feat']
    prefixes = tuple(wanted)

    # str.startswith accepts a tuple of candidate prefixes
    filtered = {key: value for key, value in sd.items() if key.startswith(prefixes)}

    return nn.Module.load_state_dict(self, filtered, strict=strict)
def depth1(**kwargs) -> TextSynch:
    """Build a TextSynch model with a single cross-attention block; ``kwargs`` are forwarded."""
    config = dict(
        text_dim=768,
        video_seq_len=192,
        max_text_seq_len=512,
        xattn_depth=1,
    )
    return TextSynch(**config, **kwargs)
def get_my_textsynch(name: str, **kwargs) -> TextSynch:
    """
    Build the TextSynch variant whose name starts with a known prefix.

    Raises:
        ValueError: if ``name`` does not start with 'depth1'.
    """
    if not name.startswith('depth1'):
        raise ValueError(f'Unknown model name: {name}')
    return depth1(**kwargs)
if __name__ == '__main__':
    # Smoke test: build the model, report its size and wrap its entry points
    # with torch.compile.
    network = get_my_textsynch('depth1')

    # print the number of parameters in terms of millions
    num_params = sum(p.numel() for p in network.parameters()) / 1e6
    print(f'Number of parameters: {num_params:.2f}M')

    # Looking up an attribute the model does not define raises AttributeError and
    # would abort the script, so missing methods are reported and skipped instead.
    for method_name, message in (
        ('encode_video_with_sync', 'Compiled encode_video_with_sync'),
        ('predict_flow', 'Compiled predict_flow'),
        ('preprocess_conditions', 'Compiled preprocess_conditions:'),
        ('forward', 'Compiled forward:'),
    ):
        method = getattr(network, method_name, None)
        if method is None:
            print(f'Skipped {method_name}: not defined on {type(network).__name__}')
            continue
        torch.compile(method)
        print(message)
+62
View File
@@ -0,0 +1,62 @@
import dataclasses
import math
@dataclasses.dataclass
class SequenceConfig:
    """Clip duration and frame rates, with the sequence lengths derived from them."""
    # general
    duration: float

    # audio
    sampling_rate: int
    spectrogram_frame_rate: int
    latent_downsample_rate: int = 2

    # visual
    clip_frame_rate: int = 8
    sync_frame_rate: int = 25
    sync_num_frames_per_segment: int = 16
    sync_step_size: int = 8
    sync_downsample_rate: int = 2

    @property
    def num_audio_frames(self) -> int:
        """Number of audio samples covered by exactly ``latent_seq_len`` latents."""
        # we need an integer number of latents
        samples = self.latent_seq_len * self.spectrogram_frame_rate
        return samples * self.latent_downsample_rate

    @property
    def latent_seq_len(self) -> int:
        """Number of latents needed to cover ``duration`` (rounded up)."""
        total_samples = self.duration * self.sampling_rate
        latents = total_samples / self.spectrogram_frame_rate / self.latent_downsample_rate
        return int(math.ceil(latents))

    @property
    def clip_seq_len(self) -> int:
        """Number of CLIP frames in ``duration`` (truncated)."""
        clip_frames = self.duration * self.clip_frame_rate
        return int(clip_frames)

    @property
    def sync_seg_len(self) -> int:
        """Number of sync segments obtained by sliding a window over the sync frames."""
        frames_after_first_window = (self.duration * self.sync_frame_rate -
                                     self.sync_num_frames_per_segment)
        return int(frames_after_first_window // self.sync_step_size + 1)

    @property
    def sync_seq_len(self) -> int:
        """Number of sync features: segments times frames per segment, downsampled."""
        frames_in_segments = self.sync_seg_len * self.sync_num_frames_per_segment
        return int(frames_in_segments / self.sync_downsample_rate)
# Presets for 8-second clips at the two supported audio sampling rates.
CONFIG_16K = SequenceConfig(duration=8.0, sampling_rate=16000, spectrogram_frame_rate=256)
CONFIG_44K = SequenceConfig(duration=8.0, sampling_rate=44100, spectrogram_frame_rate=512)

if __name__ == '__main__':
    # Self-check of the derived lengths:
    # (config, latent_seq_len, clip_seq_len, sync_seq_len, num_audio_frames)
    expectations = (
        (CONFIG_16K, 250, 64, 192, 128000),
        (CONFIG_44K, 345, 64, 192, 353280),
    )
    for config, latent_len, clip_len, sync_len, audio_frames in expectations:
        assert config.latent_seq_len == latent_len
        assert config.clip_seq_len == clip_len
        assert config.sync_seq_len == sync_len
        assert config.num_audio_frames == audio_frames

    print('Passed')
+199
View File
@@ -0,0 +1,199 @@
import logging
from typing import Any, Mapping
import einops
import torch
from torch import nn
from selva_core.ext.synchformer.motionformer import MotionFormer, SpatialTransformerEncoderLayer, BaseEncoderLayer
from selva_core.ext.synchformer.astransformer import AST
from selva_core.model.transformer_layers import (MMCrossAttentionBlock)
from selva_core.model.low_level import MLP
class ExtendedMotionFormer(MotionFormer):
    """MotionFormer split into two stages so that text conditioning can run in between."""

    def forward_segments_without_aggregation(self, x, orig_shape: tuple) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Extract features without spatial-temporal aggregation.

        Args:
            x: Input tensor of shape (BS, C, T, H, W) where S is the number of segments
            orig_shape: Original shape tuple (B, S, C, T, H, W)

        Returns:
            Tuple of (features, mask). With ``factorize_space_time`` the features are
            restored to shape (B*S, D, t, h, w); otherwise they stay as a token sequence.
        """
        tokens, x_mask = self.forward_features(x)

        assert self.extract_features

        # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
        tokens = tokens[:, 1:, :]  # without the CLS token for efficiency
        tokens = self.pre_logits(self.norm(tokens))

        if not self.factorize_space_time:
            return tokens, x_mask
        # (B*S, D, t, h, w) <- (B*S, t*h*w, D)
        return self.restore_spatio_temp_dims(tokens, orig_shape), x_mask

    def spatiotemporal_aggregation(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
        """
        Apply spatial-temporal aggregation to features.

        Args:
            x: Features tensor of shape (B*S, D, t, h, w)
            x_mask: Mask tensor

        Returns:
            Aggregated features of shape (B*S, D) or (B*S, t, D); ``x`` unchanged when
            ``factorize_space_time`` is off.
        """
        if not self.factorize_space_time:
            return x
        per_frame = self.spatial_attn_agg(x, x_mask)  # (B*S, t, D)
        # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity`
        return self.temp_attn_agg(per_frame)
class TextSynchformer(nn.Module):
def __init__(self, video: bool = True, audio: bool = False,
             text_dim: int = 1024, max_text_seq_len: int = 512, xattn_depth: int = 1):
    """
    Args:
        video: build the video feature extractor.
        audio: build the audio feature extractor.
        text_dim: feature size of the incoming text features.
        max_text_seq_len: stored on the instance.
        xattn_depth: number of text cross-attention blocks.

    Raises:
        ValueError: if both ``video`` and ``audio`` are false.
    """
    super().__init__()
    self.video = video
    self.audio = audio
    self.text_dim = text_dim
    self.max_text_seq_len = max_text_seq_len

    if not (video or audio):
        raise ValueError('At least one of video or audio should be True.')

    # Use ExtendedMotionFormer directly instead of inheriting from Synchformer
    if self.video:
        self.vfeat_extractor = ExtendedMotionFormer(
            extract_features=True,
            factorize_space_time=True,
            agg_space_module='TransformerEncoderLayer',
            agg_time_module='torch.nn.Identity',
            add_global_repr=False
        )
    if self.audio:
        self.afeat_extractor = AST(
            extract_features=True,
            max_spec_t=66,
            factorize_freq_time=True,
            agg_freq_module='TransformerEncoderLayer',
            agg_time_module='torch.nn.Identity',
            add_global_repr=False
        )

    # Transformer dimensions follow the video feature extractor when there is one,
    # otherwise fixed defaults are used.
    if self.video:
        extractor = self.vfeat_extractor
        self.embed_dim, self.num_heads, self.mlp_ratio = (
            extractor.embed_dim,
            extractor.num_heads,
            extractor.mlp_ratio,
        )
    else:
        self.embed_dim, self.num_heads, self.mlp_ratio = 768, 12, 4

    self.text_proj = nn.Sequential(
        nn.Linear(self.text_dim, self.embed_dim),
        nn.SiLU(),
        MLP(self.embed_dim, self.embed_dim * 4)
    )

    cross_blocks = [
        MMCrossAttentionBlock(self.embed_dim, self.num_heads,
                              mlp_ratio=self.mlp_ratio,
                              kernel_size=1, padding=0,
                              residual=True)
        for _ in range(xattn_depth)
    ]
    self.synch_text_cross_blocks = nn.ModuleList(cross_blocks)
def initialize_weights(self):
    """
    Initialize the text projection and the cross-attention blocks.

    Linear layers in both get Xavier-uniform weights and zero bias. For every
    cross-attention block, ``norm1`` weight/bias and ``ffn.w2`` weight are then
    set to zero. The feature extractors are not touched.
    """

    def _basic_init(module):
        if not isinstance(module, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    for submodule in (self.text_proj, self.synch_text_cross_blocks):
        submodule.apply(_basic_init)

    for block in self.synch_text_cross_blocks:
        for tensor in (block.norm1.weight, block.norm1.bias, block.ffn.w2.weight):
            nn.init.constant_(tensor, 0.0)
def forward(self, data, text_features):
video, audio = None, None
if self.video and self.audio:
video, audio = data
elif self.video:
video = data
elif self.audio:
audio = data
if self.video and video is not None:
video = self.forward_vfeat(video, text_features)
if self.audio and audio is not None:
audio = self.forward_afeat(audio, text_features)
if self.video and self.audio:
return video, audio
elif self.video:
return video
else:
return audio
def forward_vfeat(self, vis, text_f, text_mask):
B, S, Tv, C, H, W = vis.shape
vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
# Flatten for processing
orig_shape = (B, S, C, Tv, H, W)
vis = einops.rearrange(vis, 'B S C Tv H W -> (B S) C Tv H W') # vis.view(B * S, C, Tv, H, W)
vis, vis_mask = self.vfeat_extractor.forward_segments_without_aggregation(
vis, orig_shape # B*S D t h w , BS t h w
)
text_f = self.text_proj(text_f) # (B, text_dim) -> (B, embed_dim)
BS, D, t, h, w = vis.shape
vis = einops.rearrange(vis, '(B S) D t h w -> B (S t h w) D', B=B, S=S)
vis_mask = einops.rearrange(vis_mask, '(B S) t h w -> B (S t h w)', B=B, S=S) \
if vis_mask is not None else None
for block in self.synch_text_cross_blocks:
vis = block(vis, text_f, rot=None, x_mask=vis_mask, context_mask=text_mask)
vis = einops.rearrange(vis, 'B (S t h w) D -> (B S) D t h w', B=B, S=S, D=D, t=t, h=h, w=w)
vis_mask = einops.rearrange(vis_mask, 'B (S t h w) -> (B S) t h w', B=B, S=S, t=t, h=h, w=w) \
if vis_mask is not None else None
vis = self.vfeat_extractor.spatiotemporal_aggregation(
vis, vis_mask
)
vis = vis.view(B, S, *vis.shape[1:])
return vis
def forward_afeat(self, aud):
"""Forward audio features."""
raise NotImplementedError("Audio feature extraction is not implemented in TextSynchformer.")
# B, S, F, Ta = aud.shape
# aud = aud.permute(0, 1, 3, 2) # (B, S, Ta, F)
# aud, _ = self.afeat_extractor(aud)
# return aud
def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
target_keys = (['vfeat_extractor'] if self.video else []) \
+ (['afeat_extractor'] if self.audio else []) \
+ ['text_proj', 'synch_text_cross_blocks']
# discard all entries except vfeat_extractor / afeat_extractor
sd = {k: v for k, v in sd.items() if any(k.startswith(tk)
for tk in target_keys)}
return super().load_state_dict(sd, strict)
+426
View File
@@ -0,0 +1,426 @@
from typing import Optional
from inspect import isfunction
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from selva_core.ext.rotary_embeddings import apply_rope
from selva_core.model.low_level import MLP, ChannelLastConv1d, ConvMLP
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Apply an affine modulation: ``x * (1 + scale) + shift`` (broadcasting)."""
    scaled = x * (1 + scale)
    return scaled + shift
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
              attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run scaled dot-product attention and merge the heads.

    Args:
        q, k, v: tensors laid out as ``(batch, heads, tokens, head_dim)``.
        attn_mask: optional mask forwarded to
            ``F.scaled_dot_product_attention``.

    Returns:
        A contiguous ``(batch, tokens, heads * head_dim)`` tensor.
    """
    # Training crashes without these contiguous calls (CUDNN limitation);
    # believed to be related to https://github.com/pytorch/pytorch/issues/133974
    q, k, v = (t.contiguous() for t in (q, k, v))
    mask = None if attn_mask is None else attn_mask.contiguous()

    per_head = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    batch, n_heads, n_tokens, head_dim = per_head.shape
    # (b, h, n, d) -> (b, n, h, d) -> (b, n, h * d); reshape copies if needed.
    merged = per_head.transpose(1, 2).reshape(batch, n_tokens, n_heads * head_dim)
    return merged.contiguous()
def attention_debug(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    attn_mask: Optional[torch.Tensor] = None,
                    layer_idx: int = -1) -> None:
    """Dump the attention map of one layer to disk for inspection.

    Recomputes the attention logits by hand (instead of calling
    ``F.scaled_dot_product_attention``) and writes two files into the current
    working directory: the masked, scaled logits and their softmax. ``v`` is
    accepted for signature parity with ``attention`` but does not influence
    the saved tensors.

    Args:
        q, k, v: tensors laid out as ``(batch, heads, tokens, head_dim)``.
        attn_mask: optional mask; boolean masks mark allowed positions with
            True, any other dtype is added to the logits as a bias.
        layer_idx: index embedded in the output file names.
    """
    import math

    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    n_query, n_key = q.size(-2), k.size(-2)
    bias = torch.zeros(q.shape[0], q.shape[1], n_query, n_key, dtype=q.dtype, device=q.device)
    if attn_mask is not None:
        attn_mask = attn_mask.contiguous()
        if attn_mask.dtype == torch.bool:
            # Disallowed positions get -inf so softmax assigns them zero weight.
            bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            bias = attn_mask + bias

    scale_factor = 1 / math.sqrt(q.size(-1))
    logits = q @ k.transpose(-2, -1) * scale_factor
    logits += bias
    torch.save(logits.clone().cpu(), f'./debug_attn_weight_layer{layer_idx}_unnorm.pt')

    probs = torch.softmax(logits, dim=-1)
    torch.save(probs.clone().cpu(), f'./debug_attn_weight_layer{layer_idx}.pt')
def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
    """Combine per-token query and key masks into an attention mask.

    Args:
        q_shape: shape of the query tensor; the batch size is read from
            index 0 and the query length from index -2.
        k_shape: shape of the key tensor; the key length is read from index -2.
        device: device on which missing masks are created.
        q_mask: optional ``(batch, n_query)`` mask; all-True when omitted.
        k_mask: optional ``(batch, n_key)`` mask; all-True when omitted.

    Returns:
        The outer product of the two masks with shape
        ``(batch, 1, n_query, n_key)``; the singleton axis broadcasts over heads.

    Raises:
        ValueError: if a supplied mask is not 2-D.
    """
    batch, n_query, n_key = q_shape[0], q_shape[-2], k_shape[-2]

    # Build a default mask only when it is needed; the earlier version
    # allocated both all-ones tensors on every call.
    if q_mask is None:
        q_mask = torch.ones((batch, n_query), device=device, dtype=torch.bool)
    if k_mask is None:
        k_mask = torch.ones((batch, n_key), device=device, dtype=torch.bool)

    if q_mask.dim() != 2 or k_mask.dim() != 2:
        raise ValueError(
            f'q_mask and k_mask must be 2-D (batch, tokens); '
            f'got {tuple(q_mask.shape)} and {tuple(k_mask.shape)}.')

    # (b, i) -> (b, 1, i, 1) and (b, j) -> (b, 1, 1, j); the product broadcasts
    # to (b, 1, i, j).
    return q_mask[:, None, :, None] * k_mask[:, None, None, :]
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalised queries and keys.

    The module owns only the fused q/k/v projection and the per-head norms;
    it has no output projection of its own.

    Args:
        dim: channel size of the input tokens.
        nheads: number of attention heads (each head has ``dim // nheads``
            channels).
    """

    def __init__(self, dim: int, nheads: int):
        super().__init__()
        self.dim = dim
        self.nheads = nheads

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = nn.RMSNorm(dim // nheads)
        self.k_norm = nn.RMSNorm(dim // nheads)

        # The fused projection is laid out head-major with q/k/v innermost.
        self.split_into_heads = Rearrange('b n (h d j) -> b h n d j',
                                          h=nheads,
                                          d=dim // nheads,
                                          j=3)

    def pre_attention(
            self, x: torch.Tensor,
            rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` to normalised (and optionally rotated) q, k and v.

        Args:
            x: tokens of shape ``(batch, n_tokens, dim)``.
            rot: rotary embedding passed to ``apply_rope`` for q and k, or
                ``None`` to skip rotation.

        Returns:
            ``(q, k, v)``, each ``(batch, nheads, n_tokens, dim // nheads)``.
        """
        qkv = self.qkv(x)
        q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1)
        q = q.squeeze(-1)
        k = k.squeeze(-1)
        v = v.squeeze(-1)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if rot is not None:
            q = apply_rope(q, rot)
            k = apply_rope(k, rot)

        return q, k, v

    def forward(
            self,
            x: torch.Tensor,  # batch_size * n_tokens * n_channels
            q_mask: Optional[torch.Tensor] = None,
            k_mask: Optional[torch.Tensor] = None,
            rot: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend ``x`` to itself.

        Args:
            x: tokens of shape ``(batch, n_tokens, dim)``.
            q_mask: optional ``(batch, n_tokens)`` mask over query positions.
            k_mask: optional ``(batch, n_tokens)`` mask over key positions.
            rot: optional rotary embedding for q and k.

        Returns:
            Attention output of shape ``(batch, n_tokens, dim)``.
        """
        # pre_attention returns (q, k, v) and requires `rot`; the previous
        # call omitted `rot` and swapped k with v.
        q, k, v = self.pre_attention(x, rot)

        if q_mask is not None or k_mask is not None:
            attn_mask = create_mask(q.shape, k.shape, q.device,
                                    q_mask=q_mask, k_mask=k_mask)
        else:
            attn_mask = None

        out = attention(q, k, v, attn_mask)
        return out
class CrossAttention(nn.Module):
    """Multi-head cross-attention from tokens ``x`` to a context ``c``.

    Queries come from ``x`` through ``q_proj``; keys and values come from ``c``
    through the fused ``kv_proj``. Queries and keys are RMS-normalised per
    head. There is no output projection in this module.
    """

    def __init__(self, dim: int, nheads: int):
        """
        Args:
            dim (int): Channel size of both inputs.
            nheads (int): Number of attention heads.
        """
        super().__init__()
        head_dim = dim // nheads

        self.dim = dim
        self.nheads = nheads

        self.q_proj = nn.Linear(dim, dim, bias=True)
        self.kv_proj = nn.Linear(dim, dim * 2, bias=True)
        self.q_norm = nn.RMSNorm(head_dim)
        self.k_norm = nn.RMSNorm(head_dim)

        self.split_q_into_heads = Rearrange('b n (h d) -> b h n d', h=nheads, d=head_dim)
        # Fused k/v channels are head-major with the k/v index innermost.
        self.split_kv_into_heads = Rearrange('b n (h d j) -> b h n d j',
                                             h=nheads, d=head_dim, j=2)

    def pre_attention(
            self, x: torch.Tensor, c: torch.Tensor,
            rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build q from ``x`` and k, v from ``c``.

        Args:
            x: ``(batch, n_tokens, dim)`` query-side tokens.
            c: ``(batch, n_cond_tokens, dim)`` context tokens.
            rot: rotary embedding applied to the queries only, or ``None``.
        """
        queries = self.split_q_into_heads(self.q_proj(x))
        keys, values = self.split_kv_into_heads(self.kv_proj(c)).unbind(dim=-1)

        queries = self.q_norm(queries)
        keys = self.k_norm(keys)
        if rot is not None:
            queries = apply_rope(queries, rot)

        return queries, keys, values

    def forward(
            self,
            x: torch.Tensor,  # batch_size * n_tokens * n_channels
            c: torch.Tensor,  # batch_size * n_cond_tokens * n_channels
            context_mask: Optional[torch.Tensor] = None,
            rot: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Attend ``x`` to ``c``, optionally masking context tokens."""
        q, k, v = self.pre_attention(x, c, rot)
        attn_mask = (None if context_mask is None else
                     create_mask(q.shape, k.shape, q.device, k_mask=context_mask))
        return attention(q, k, v, attn_mask)
class MMCrossAttentionBlock(nn.Module):
    """Cross-attention block: attention update, then a feed-forward update.

    With ``kernel_size == 1`` the output projection and FFN are plain
    linear/MLP layers; otherwise their channel-last convolutional
    counterparts are used.

    Args:
        dim: channel size.
        nhead: number of attention heads.
        mlp_ratio: hidden size of the FFN as a multiple of ``dim``.
        kernel_size: kernel size of the projection and FFN.
        padding: padding for the convolutional variants.
        residual: add the attention update to ``x`` (True) or replace ``x``
            with it (False).
    """

    def __init__(self,
                 dim: int,
                 nhead: int,
                 mlp_ratio: float = 4.0,
                 kernel_size: int = 7,
                 padding: int = 3,
                 residual: bool = True):
        super().__init__()
        pointwise = kernel_size == 1
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim, elementwise_affine=True)
        self.attn = CrossAttention(dim, nhead)
        self.linear1 = (nn.Linear(dim, dim) if pointwise else
                        ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding))

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=True)
        self.ffn = (MLP(dim, hidden_dim) if pointwise else
                    ConvMLP(dim, hidden_dim, kernel_size=kernel_size, padding=padding))

        self.residual = residual

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
        """Return ``(q, k, v)`` for tokens ``x`` and context ``c``.

        No normalisation is applied to ``x`` here; ``norm1`` is applied to the
        attention output in ``post_attention`` instead.
        """
        return tuple(self.attn.pre_attention(x, c, rot))

    def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor):
        """Fold the attention output into ``x`` and apply the FFN."""
        attn_update = self.norm1(self.linear1(attn_out))
        # Post-projection norm on the attention branch; see e.g.
        # https://github.com/haidog-yaqub/EzAudio/blob/2eb0bd90013584c6e28a6c14ec28b935f1e78de5/src/models/blocks.py#L158
        x = (x + attn_update) if self.residual else attn_update
        return x + self.ffn(self.norm2(x))

    def forward(self, x: torch.Tensor, cond: torch.Tensor,
                rot: Optional[torch.Tensor],
                x_mask: Optional[torch.Tensor] = None,
                context_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Let ``x`` attend to ``cond``.

        Args:
            x: query-side tokens, ``(batch, n_tokens, dim)``.
            cond: context tokens, ``(batch, n_cond_tokens, dim)``.
            rot: optional rotary embedding for the queries.
            x_mask: optional mask over the tokens of ``x``.
            context_mask: optional mask over the tokens of ``cond``.
        """
        q, k, v = self.pre_attention(x, cond, rot)

        attn_mask = None
        if x_mask is not None or context_mask is not None:
            attn_mask = create_mask(q.shape, k.shape, q.device, q_mask=x_mask, k_mask=context_mask)

        return self.post_attention(x, attention(q, k, v, attn_mask=attn_mask))
class MMDitSingleBlock(nn.Module):
    """Self-attention block with adaLN modulation from a conditioning vector.

    A ``pre_only`` block only produces q/k/v (two modulation chunks) and
    leaves ``x`` untouched afterwards. A full block uses six modulation chunks
    and additionally owns the output projection and the feed-forward network.
    """

    def __init__(self,
                 dim: int,
                 nhead: int,
                 mlp_ratio: float = 4.0,
                 pre_only: bool = False,
                 kernel_size: int = 7,
                 padding: int = 3):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = SelfAttention(dim, nhead)
        self.pre_only = pre_only

        if pre_only:
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
            return

        pointwise = kernel_size == 1
        hidden_dim = int(dim * mlp_ratio)

        self.linear1 = (nn.Linear(dim, dim) if pointwise else
                        ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding))
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.ffn = (MLP(dim, hidden_dim) if pointwise else
                    ConvMLP(dim, hidden_dim, kernel_size=kernel_size, padding=padding))
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
        """Modulate ``x`` with ``c`` and project it to q, k, v.

        Returns:
            ``((q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp))``; the
            second tuple holds only ``None`` for a ``pre_only`` block.
        """
        modulation = self.adaLN_modulation(c)
        if self.pre_only:
            shift_msa, scale_msa = modulation.chunk(2, dim=-1)
            post_mods = (None, None, None, None)
        else:
            (shift_msa, scale_msa, gate_msa,
             shift_mlp, scale_mlp, gate_mlp) = modulation.chunk(6, dim=-1)
            post_mods = (gate_msa, shift_mlp, scale_mlp, gate_mlp)

        normed = modulate(self.norm1(x), shift_msa, scale_msa)
        q, k, v = self.attn.pre_attention(normed, rot)
        return (q, k, v), post_mods

    def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: tuple[torch.Tensor]):
        """Apply the gated attention and FFN updates (no-op when ``pre_only``)."""
        if self.pre_only:
            return x

        gate_msa, shift_mlp, scale_mlp, gate_mlp = c
        x = x + self.linear1(attn_out) * gate_msa
        ffn_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        return x + self.ffn(ffn_in) * gate_mlp

    def forward(self, x: torch.Tensor, cond: torch.Tensor,
                rot: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the block on ``x`` conditioned on ``cond`` (unmasked attention)."""
        (q, k, v), post_mods = self.pre_attention(x, cond, rot)
        return self.post_attention(x, attention(q, k, v), post_mods)
class JointBlock(nn.Module):
    """Joint attention over latent, clip and text token streams.

    Each stream has its own ``MMDitSingleBlock``; their q/k/v are concatenated
    along the token axis so every token attends to all three streams, and the
    result is split back per stream. With ``pre_only=True`` the clip and text
    streams only contribute q/k/v and are returned unchanged.
    """

    def __init__(self, dim: int, nhead: int, mlp_ratio: float = 4.0, pre_only: bool = False):
        super().__init__()
        self.pre_only = pre_only
        # The latent stream is always fully updated, regardless of pre_only.
        self.latent_block = MMDitSingleBlock(dim,
                                             nhead,
                                             mlp_ratio,
                                             pre_only=False,
                                             kernel_size=3,
                                             padding=1)
        self.clip_block = MMDitSingleBlock(dim,
                                           nhead,
                                           mlp_ratio,
                                           pre_only=pre_only,
                                           kernel_size=3,
                                           padding=1)
        self.text_block = MMDitSingleBlock(dim, nhead, mlp_ratio, pre_only=pre_only, kernel_size=1)

    def _joint_qkv(self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor,
                   global_c: torch.Tensor, extended_c: torch.Tensor,
                   latent_rot: torch.Tensor, clip_rot: torch.Tensor):
        """Return the token-concatenated ``[q, k, v]`` and per-stream modulations."""
        x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
        c_qkv, c_mod = self.clip_block.pre_attention(clip_f, global_c, clip_rot)
        t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)

        # q/k/v are (batch, heads, tokens, head_dim); dim=2 is the token axis.
        joint_qkv = [torch.cat([x_qkv[i], c_qkv[i], t_qkv[i]], dim=2) for i in range(3)]
        return joint_qkv, x_mod, c_mod, t_mod

    def forward(
            self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor,
            global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: torch.Tensor,
            clip_rot: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Update the three streams with one round of joint attention.

        Args:
            latent: ``(batch, N1, dim)`` latent tokens.
            clip_f: ``(batch, N2, dim)`` clip tokens.
            text_f: ``(batch, N3, dim)`` text tokens.
            global_c: conditioning for the clip and text blocks.
            extended_c: conditioning for the latent block.
            latent_rot: rotary embedding for the latent stream.
            clip_rot: rotary embedding for the clip stream.

        Returns:
            ``(latent, clip_f, text_f)``; the last two are the inputs
            themselves when the block is ``pre_only``.
        """
        joint_qkv, x_mod, c_mod, t_mod = self._joint_qkv(
            latent, clip_f, text_f, global_c, extended_c, latent_rot, clip_rot)

        latent_len = latent.shape[1]
        clip_len = clip_f.shape[1]

        attn_out = attention(*joint_qkv)
        x_attn_out = attn_out[:, :latent_len]
        c_attn_out = attn_out[:, latent_len:latent_len + clip_len]
        t_attn_out = attn_out[:, latent_len + clip_len:]

        latent = self.latent_block.post_attention(latent, x_attn_out, x_mod)
        if not self.pre_only:
            clip_f = self.clip_block.post_attention(clip_f, c_attn_out, c_mod)
            text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod)

        return latent, clip_f, text_f

    def forward_debug(
        self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor,
        global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: torch.Tensor,
        clip_rot: torch.Tensor,
        layer_idx: int = -1,
    ) -> None:
        """Dump this block's joint attention map via ``attention_debug``.

        Takes the same inputs as ``forward`` plus ``layer_idx`` (used in the
        dump file names). Does not update any stream and returns ``None``.
        """
        joint_qkv, _, _, _ = self._joint_qkv(
            latent, clip_f, text_f, global_c, extended_c, latent_rot, clip_rot)
        attention_debug(*joint_qkv, layer_idx=layer_idx)
        return None
class FinalBlock(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a channel-last conv.

    Args:
        dim: channel size of the incoming latent tokens.
        out_dim: channel size produced by the final convolution.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.conv = ChannelLastConv1d(dim, out_dim, kernel_size=7, padding=3)

    def forward(self, latent, c):
        """Modulate the normalised ``latent`` with ``c`` and project to ``out_dim``."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm(latent), shift, scale)
        return self.conv(modulated)
View File
+46
View File
@@ -0,0 +1,46 @@
from typing import Optional
import numpy as np
import torch
class DiagonalGaussianDistribution:
    """Diagonal Gaussian parameterised by concatenated mean and log-variance.

    Args:
        parameters: tensor whose dim 1 holds the mean in its first half and
            the log-variance in its second half.
        deterministic: when True, variance and standard deviation are set to
            zero so ``sample`` returns the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp to keep exp() numerically finite.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self, rng: Optional[torch.Generator] = None):
        """Draw ``mean + std * eps`` with ``eps`` ~ N(0, I), using ``rng`` if given."""
        r = torch.empty_like(self.mean).normal_(generator=rng)
        x = self.mean + self.std * r
        return x

    def kl(self, other=None):
        """Element-wise KL divergence to ``other`` (standard normal if omitted).

        Returns a one-element zero tensor for a deterministic distribution.
        No reduction is applied; callers sum or average as needed.
        """
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            # KL(N(mu, var) || N(0, 1)) = 0.5 * (mu^2 + var - 1 - log var).
            # The 0.5 must scale the whole sum, matching the `other` branch.
            return 0.5 * (torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar)
        return 0.5 * (torch.pow(self.mean - other.mean, 2) / other.var +
                      self.var / other.var - 1.0 - self.logvar + other.logvar)

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``.

        Returns a one-element zero tensor for a deterministic distribution.
        """
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
                               dim=dims)

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
+18
View File
@@ -0,0 +1,18 @@
import torch
from selva_core.utils.misc import instantiate_from_config
from selva_core.model.networks_video_enc import TextSynch as TextSynchVideoEnc
from selva_core.model.networks_generator import MMAudio
# Model classes a factory is allowed to return; `create_model_from_factory`
# checks its result against this tuple with isinstance().
_MODEL_ZOO = (TextSynchVideoEnc, MMAudio)
def create_model_from_factory(factory_path: str, name: str, **kwargs) -> torch.nn.Module:
    """Build a model by calling the factory located at ``factory_path``.

    ``name`` and every extra keyword argument are passed to
    ``instantiate_from_config`` as one parameter dict. The result must be an
    instance of one of the classes in ``_MODEL_ZOO``.
    """
    factory_params = dict(name=name, **kwargs)
    model = instantiate_from_config(factory_path, factory_params)
    assert isinstance(model, _MODEL_ZOO), f"Model {type(model)} is not a valid model type."
    return model
+192
View File
@@ -0,0 +1,192 @@
from typing import Literal, Optional
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip import create_model_from_pretrained
from torchvision.transforms import Normalize
from transformers import T5TokenizerFast, T5EncoderModel
from selva_core.ext.autoencoder import AutoEncoderModule
from selva_core.ext.mel_converter import get_mel_converter
from selva_core.ext.synchformer import Synchformer
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
from selva_core.utils.transforms import generate_multiple_segments
def patch_clip(clip_model):
    """Replace ``clip_model.encode_text`` so it returns per-token hidden states.

    The replacement runs the token embedding, positional embedding,
    transformer and final layer norm, and returns the full sequence
    (optionally L2-normalised along the last dimension). The model is
    patched in place and returned.
    """
    # Adapted from
    # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
    def new_encode_text(self, text, normalize: bool = False):
        dtype = self.transformer.get_cast_dtype()

        # [batch_size, n_ctx, d_model]
        hidden = self.token_embedding(text).to(dtype) + self.positional_embedding.to(dtype)
        hidden = self.transformer(hidden, attn_mask=self.attn_mask)
        hidden = self.ln_final(hidden)  # [batch_size, n_ctx, transformer.width]

        if normalize:
            return F.normalize(hidden, dim=-1)
        return hidden

    # Bind the function to this instance so `self` resolves to clip_model.
    clip_model.encode_text = new_encode_text.__get__(clip_model)
    return clip_model
class FeaturesUtils(nn.Module):
    """Frozen bundle of the feature encoders and the audio VAE/vocoder.

    When ``enable_conditions`` is True the module loads a CLIP model, a
    Synchformer visual encoder and a T5 text encoder (with tokenizers). When
    ``tod_vae_ckpt`` is given it also builds the mel converter and the
    autoencoder/vocoder module. The module is always kept in eval mode.

    Args:
        tod_vae_ckpt: VAE checkpoint path; ``None`` skips the audio autoencoder.
        bigvgan_vocoder_ckpt: vocoder checkpoint path passed to the autoencoder.
        synchformer_ckpt: Synchformer checkpoint path; read only when
            ``enable_conditions`` is True.
        enable_conditions: load the conditioning encoders.
        mode: audio mode forwarded to ``get_mel_converter`` and
            ``AutoEncoderModule``; only used when ``tod_vae_ckpt`` is given.
        need_vae_encoder: forwarded to ``AutoEncoderModule``.
    """

    def __init__(
        self,
        *,
        tod_vae_ckpt: Optional[str] = None,
        bigvgan_vocoder_ckpt: Optional[str] = None,
        synchformer_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        # Previously written `mode=Literal[...]`, which made the typing object
        # itself the default value. It is an annotation; None now means
        # "not specified" and is passed through unchanged.
        mode: Optional[Literal['16k', '44k']] = None,
        need_vae_encoder: bool = True,
    ):
        super().__init__()

        if enable_conditions:
            self.clip_model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',
                                                           return_transform=False)
            self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                             std=[0.26862954, 0.26130258, 0.27577711])
            self.clip_model = patch_clip(self.clip_model)
            self.tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')  # same as 'ViT-H-14'

            self.synchformer = Synchformer(video=True, audio=False)
            self.synchformer.load_state_dict(
                torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))

            self.text_encoder_t5 = T5EncoderModel.from_pretrained('google/flan-t5-base')
            self.tokenizer_t5 = T5TokenizerFast.from_pretrained('google/flan-t5-base')
        else:
            self.clip_model = None
            self.clip_preprocess = None
            self.synchformer = None
            self.tokenizer_clip = None
            self.text_encoder_t5 = None
            self.tokenizer_t5 = None

        if tod_vae_ckpt is not None:
            self.mel_converter = get_mel_converter(mode)
            self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                                         vocoder_ckpt_path=bigvgan_vocoder_ckpt,
                                         mode=mode,
                                         need_vae_encoder=need_vae_encoder)
        else:
            # Define both attributes so later access never raises AttributeError.
            self.mel_converter = None
            self.tod = None

    def compile(self):
        """Wrap the heavy forward paths with ``torch.compile``."""
        if self.clip_model is not None:
            self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
            self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
        if self.synchformer is not None:
            self.synchformer = torch.compile(self.synchformer)
            self.synchformer.forward_vfeat = torch.compile(self.synchformer.forward_vfeat)
        if self.text_encoder_t5 is not None:
            self.text_encoder_t5.forward = torch.compile(self.text_encoder_t5.forward)
        self.decode = torch.compile(self.decode)
        self.vocode = torch.compile(self.vocode)

    def train(self, mode: bool = True) -> 'FeaturesUtils':
        """Keep the module in eval mode regardless of ``mode``.

        ``mode`` defaults to True like ``nn.Module.train`` so a bare
        ``.train()`` call works; the value is ignored.
        """
        return super().train(False)

    @torch.inference_mode()
    def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        """Encode frames with CLIP.

        Args:
            x: ``(B, T, 3, 384, 384)`` frames.
            batch_size: frames per CLIP call; negative means all at once.

        Returns:
            ``(B, T, D)`` normalised image embeddings.
        """
        assert self.clip_model is not None, 'CLIP is not loaded'
        # x: (B, T, C, H, W) H/W: 384
        b, t, c, h, w = x.shape
        assert c == 3 and h == 384 and w == 384
        x = self.clip_preprocess(x)
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        outputs = []
        if batch_size < 0:
            batch_size = b * t
        for i in range(0, b * t, batch_size):
            outputs.append(self.clip_model.encode_image(x[i:i + batch_size], normalize=True))
        x = torch.cat(outputs, dim=0)
        x = rearrange(x, '(b t) d -> b t d', b=b)
        return x

    @torch.inference_mode()
    def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        """Encode frames with Synchformer over overlapping segments.

        Args:
            x: ``(B, T, 3, 224, 224)`` frames.
            batch_size: segments per Synchformer call; negative means ``B``.

        Returns:
            ``(B, S * t, D)`` features, with the segment axis folded into the
            token axis.
        """
        assert self.synchformer is not None, 'Synchformer is not loaded'
        # x: (B, T, C, H, W) H/W: 224
        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224

        # Partition the video into 16-frame segments with a stride of 8.
        segment_size = 16
        step_size = 8
        x = generate_multiple_segments(x, segment_size, step_size)  # (B, S, T, C, H, W)
        num_segments = x.shape[1]
        outputs = []
        if batch_size < 0:
            batch_size = b
        # Each segment becomes its own batch item with a single-segment axis.
        x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
        for i in range(0, b * num_segments, batch_size):
            outputs.append(self.synchformer.forward_vfeat(x[i:i + batch_size]))
        x = torch.cat(outputs, dim=0)
        x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
        return x

    @torch.inference_mode()
    def encode_text_clip(self, text: list[str]) -> torch.Tensor:
        """Return normalised per-token CLIP text features for ``text``."""
        assert self.clip_model is not None, 'CLIP is not loaded'
        assert self.tokenizer_clip is not None, 'Tokenizer is not loaded'
        # x: (B, L)
        tokens = self.tokenizer_clip(text).to(self.device)
        return self.clip_model.encode_text(tokens, normalize=True)

    @torch.inference_mode()
    def encode_text_t5(self, text: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode ``text`` with T5.

        Returns:
            ``(hidden_states, mask)``: the encoder's last hidden state
            ``(B, L, D)`` and a boolean ``(B, L)`` mask that is True on
            non-padding tokens.
        """
        assert self.text_encoder_t5 is not None, 'T5 encoder is not loaded'
        assert self.tokenizer_t5 is not None, 'T5 tokenizer is not loaded'
        device = self.text_encoder_t5.device
        batch = self.tokenizer_t5(
            text,
            max_length=self.tokenizer_t5.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
            device
        )
        encoder_hidden_states = self.text_encoder_t5(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state  # (B, L, D)
        boolean_encoder_mask = (attention_mask == 1).to(device)  # (B, L)
        return encoder_hidden_states, boolean_encoder_mask

    @torch.inference_mode()
    def encode_audio(self, x) -> DiagonalGaussianDistribution:
        """Convert a waveform to mel and encode it with the VAE."""
        assert self.tod is not None, 'VAE is not loaded'
        # x: (B * L)
        mel = self.mel_converter(x)
        dist = self.tod.encode(mel)
        return dist

    @torch.inference_mode()
    def vocode(self, mel: torch.Tensor) -> torch.Tensor:
        """Run the vocoder on ``mel``."""
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.vocode(mel)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latents with the VAE; dims 1 and 2 of ``z`` are swapped first."""
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.decode(z.transpose(1, 2))

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype
@@ -0,0 +1,39 @@
import logging
# Module-wide logger; getLogger() without a name returns the root logger.
log = logging.getLogger()
def get_parameter_groups(model, cfg, print_log=False):
    """
    Collect the trainable parameters of ``model`` into a single optimizer group.

    Every parameter with ``requires_grad`` set is included once, in
    ``named_parameters()`` order. The group uses ``cfg.learning_rate`` and
    ``cfg.weight_decay``. ``print_log`` is accepted but has no effect.

    Returns a one-element list that can be passed to an optimizer.
    """
    weight_decay = cfg.weight_decay
    base_lr = cfg.learning_rate

    # inspired by detectron2: remember parameters already collected so a
    # shared parameter is never added twice
    seen = set()
    trainable = []
    for _name, param in model.named_parameters():
        if not param.requires_grad or param in seen:
            continue
        seen.add(param)
        trainable.append(param)

    return [{'params': trainable, 'lr': base_lr, 'weight_decay': weight_decay}]
+12
View File
@@ -0,0 +1,12 @@
from typing import Optional
import torch
def log_normal_sample(x: torch.Tensor,
                      generator: Optional[torch.Generator] = None,
                      m: float = 0.0,
                      s: float = 1.0) -> torch.Tensor:
    """Draw one value in (0, 1) per batch element of ``x``.

    A normal sample with mean ``m`` and standard deviation ``s`` is drawn for
    each of the ``x.shape[0]`` batch elements (on ``x``'s device, using
    ``generator`` if given) and passed through a sigmoid. Only the batch size
    and device of ``x`` are used.
    """
    batch = x.shape[0]
    noise = torch.randn(batch, device=x.device, generator=generator)
    return torch.sigmoid(noise * s + m)