chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# https://github.com/facebookresearch/DiT
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal encoding of the timestep is passed through a small MLP
    (Linear -> SiLU -> Linear) that outputs ``dim`` features.

    :param dim: width of the MLP output (and of its hidden layer); must be even.
    :param frequency_embedding_size: width of the sinusoidal encoding fed to the MLP.
    :param max_period: rescales the base-10000 frequency table by ``10000 / max_period``.
    """

    def __init__(self, dim, frequency_embedding_size, max_period):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.dim = dim
        self.max_period = max_period
        assert dim % 2 == 0, 'dim must be even.'

        # NOTE(review): the encoding has 2 * ceil(frequency_embedding_size / 2)
        # columns, so an odd frequency_embedding_size would not match the first
        # Linear layer; only ``dim`` is checked above.
        with torch.autocast('cuda', enabled=False):
            exponents = (torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
                         frequency_embedding_size)
            freqs = 1.0 / (10000**exponents)
            freq_scale = 10000 / max_period
            # register_buffer is the long-standing equivalent of assigning an
            # nn.Buffer; the table is non-persistent, so it is never stored in
            # (or expected from) a state_dict.
            self.register_buffer('freqs', freq_scale * freqs, persistent=False)

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :return: an (N, D) float Tensor of positional embeddings, where D is
                 twice the length of the frequency table (cosines first, then sines).
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

        args = t[:, None].float() * self.freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return embedding

    def forward(self, t):
        """
        Embed the timesteps ``t`` (shape (N,)) into an (N, dim) Tensor.

        Floating-point ``t`` keeps its dtype for the MLP input, as before.
        Integer ``t`` would truncate the sinusoids and fail inside nn.Linear,
        so in that case the MLP's own parameter dtype is used instead.
        """
        t_freq = self.timestep_embedding(t)
        if t.is_floating_point():
            mlp_dtype = t.dtype
        else:
            mlp_dtype = self.mlp[0].weight.dtype
        t_emb = self.mlp(t_freq.to(mlp_dtype))
        return t_emb
|
||||
@@ -0,0 +1,87 @@
|
||||
import logging
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torchdiffeq import odeint
|
||||
|
||||
# Module-scoped logger: records carry this module's name and still propagate
# to whatever handlers the host application installed on the root logger.
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Partially from https://github.com/gle-bellier/flow-matching
|
||||
class FlowMatching:
    """
    Conditional flow-matching helper: builds training targets/losses and
    integrates a learned flow between the prior (t=0) and the data (t=1).

    :param min_sigma: minimum noise scale used in the conditional flow.
    :param inference_mode: 'euler' (fixed-step) or 'adaptive' (torchdiffeq odeint).
    :param num_steps: number of steps in the euler inference mode.
    :param target: what the network output is compared against in ``loss``,
                   either 'v' (flow) or 'x1' (data).
    """

    def __init__(self, min_sigma: float = 0.0, inference_mode='euler', num_steps: int = 25,
                 target: str = 'v'):
        super().__init__()
        self.min_sigma = min_sigma
        self.inference_mode = inference_mode
        self.num_steps = num_steps
        self.target = target

        assert self.inference_mode in ['euler', 'adaptive']
        if self.inference_mode == 'adaptive' and num_steps > 0:
            log.info('The number of steps is ignored in adaptive inference mode ')

    @staticmethod
    def _expand_time(t: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        # Broadcast per-sample times (B,) over every non-batch axis of ``like``.
        # For 3-D inputs this is exactly ``t[:, None, None].expand_as(like)``.
        return t.reshape(-1, *([1] * (like.dim() - 1))).expand_as(like)

    def get_conditional_flow(self, x0: torch.Tensor, x1: torch.Tensor,
                             t: torch.Tensor) -> torch.Tensor:
        """
        Return psi_t(x), eq 22 in flow matching for generative models:
        ``(1 - (1 - min_sigma) * t) * x0 + t * x1``.

        :param x0: prior sample, shape (B, ...).
        :param x1: data sample, same shape as x0.
        :param t: per-sample times, shape (B,).
        """
        t = self._expand_time(t, x0)
        return (1 - (1 - self.min_sigma) * t) * x0 + t * x1

    def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor,
             xt: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Return the mean squared error without reducing the batch dimension.

        :param predicted_v: network output, shape (B, ...).
        :param x0: prior sample.
        :param x1: data sample.
        :param xt: interpolated sample; required when target is 'x1'.
        :param t: per-sample times (B,); required when target is 'x1'.
        :return: a (B,) Tensor of per-sample losses.
        :raises ValueError: if target is 'x1' and xt/t are missing, or the
                            target is unknown.
        """
        reduce_dim = list(range(1, len(predicted_v.shape)))
        if self.target == 'v':
            target_v = x1 - (1 - self.min_sigma) * x0
            return (predicted_v - target_v).pow(2).mean(dim=reduce_dim)
        elif self.target == 'x1':
            if xt is None or t is None:
                raise ValueError("xt and t must be provided when target is 'x1'")
            t = self._expand_time(t, x0)
            predicted_x1 = xt + (1 - t) * predicted_v - self.min_sigma * x0
            return (predicted_x1 - x1).pow(2).mean(dim=reduce_dim)
        else:
            raise ValueError(f"Unknown target: {self.target}. Supported targets are 'v' and 'x1'.")

    def get_x0_xt_c(
        self,
        x1: torch.Tensor,
        t: torch.Tensor,
        Cs: list[torch.Tensor],
        generator: Optional[torch.Generator] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]:
        """
        Sample Gaussian noise x0 shaped like x1 and interpolate it towards x1 at time t.

        :return: ``(x0, x1, xt, Cs)``; ``Cs`` is passed through untouched.
        """
        x0 = torch.empty_like(x1).normal_(generator=generator)

        xt = self.get_conditional_flow(x0, x1, t)
        return x0, x1, xt, Cs

    def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor:
        """Integrate from t=1 to t=0, starting at x1."""
        return self.run_t0_to_t1(fn, x1, 1, 0)

    def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
        """Integrate from t=0 to t=1, starting at x0."""
        return self.run_t0_to_t1(fn, x0, 0, 1)

    def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
        """
        Integrate ``dx/dt = fn(t, x)`` from t0 to t1 and return the final state.

        :param fn: a function that takes (t, x) and returns the direction x0->x1.
        :param x0: the state at t0.
        :return: the state at the end of the integration, same shape as x0.
        """
        if self.inference_mode == 'adaptive':
            # odeint returns the solution at every requested time point stacked
            # along a new leading axis; only the state at t1 is wanted here so
            # that both modes return a tensor shaped like x0.
            times = torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype)
            return odeint(fn, x0, times)[-1]
        elif self.inference_mode == 'euler':
            x = x0
            steps = torch.linspace(t0, t1 - self.min_sigma, self.num_steps + 1)
            for t, next_t in zip(steps[:-1], steps[1:]):
                flow = fn(t, x)
                dt = next_t - t
                x = x + dt * flow

            return x
        raise ValueError(f'Unknown inference mode: {self.inference_mode}')
|
||||
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class ChannelLastConv1d(nn.Conv1d):
    """nn.Conv1d that consumes and produces channel-last (B, N, C) tensors."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv1d wants (B, C, N): swap the last two axes going in and coming out.
        channels_first = x.permute(0, 2, 1)
        convolved = super().forward(channels_first)
        return convolved.permute(0, 2, 1)
|
||||
|
||||
|
||||
# https://github.com/Stability-AI/sd3-ref
|
||||
class MLP(nn.Module):
    """Gated feed-forward block computing ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input (and output) dimension.
            hidden_dim (int): Requested hidden dimension; the width actually
                used is two thirds of it, rounded up to a multiple of
                ``multiple_of``.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.

        Attributes:
            w1 (nn.Linear): bias-free projection dim -> hidden (SiLU-gated branch).
            w2 (nn.Linear): bias-free projection hidden -> dim (output).
            w3 (nn.Linear): bias-free projection dim -> hidden (linear branch).
        """
        super().__init__()
        two_thirds = int(2 * hidden_dim / 3)
        num_groups = (two_thirds + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_groups

        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
||||
|
||||
|
||||
class ConvMLP(nn.Module):
    """Convolutional gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``
    where every projection is a bias-free ChannelLastConv1d."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input (and output) channel dimension.
            hidden_dim (int): Requested hidden dimension; the width actually
                used is two thirds of it, rounded up to a multiple of
                ``multiple_of``.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            kernel_size (int): Kernel size shared by the three convolutions.
            padding (int): Padding shared by the three convolutions.

        Attributes:
            w1 (ChannelLastConv1d): dim -> hidden (SiLU-gated branch).
            w2 (ChannelLastConv1d): hidden -> dim (output).
            w3 (ChannelLastConv1d): dim -> hidden (linear branch).
        """
        super().__init__()
        two_thirds = int(2 * hidden_dim / 3)
        num_groups = (two_thirds + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_groups

        # The three convolutions share everything except their channel counts.
        conv_kwargs = dict(bias=False, kernel_size=kernel_size, padding=padding)
        self.w1 = ChannelLastConv1d(dim, inner_dim, **conv_kwargs)
        self.w2 = ChannelLastConv1d(inner_dim, dim, **conv_kwargs)
        self.w3 = ChannelLastConv1d(dim, inner_dim, **conv_kwargs)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
||||
@@ -0,0 +1,475 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from selva_core.ext.rotary_embeddings import compute_rope_rotations
|
||||
from selva_core.model.embeddings import TimestepEmbedder
|
||||
from selva_core.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
||||
from selva_core.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock)
|
||||
|
||||
# Module-scoped logger: records carry this module's name and still propagate
# to whatever handlers the host application installed on the root logger.
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PreprocessedConditions:
    """Conditioning features that do not depend on the latent or the time step.

    Produced by MMAudio.preprocess_conditions and consumed by
    MMAudio.predict_flow; the fields are reassigned in place by
    MMAudio.get_empty_conditions, so the dataclass must stay mutable.
    """
    # projected clip features, one row per clip frame
    clip_f: torch.Tensor
    # projected sync features, resampled to the latent sequence length
    sync_f: torch.Tensor
    # projected text features, one row per text token
    text_f: torch.Tensor
    # global clip condition (projection of the mean over the clip sequence)
    clip_f_c: torch.Tensor
    # global text condition (projection of the mean over the text sequence)
    text_f_c: torch.Tensor
|
||||
|
||||
|
||||
# Partially from https://github.com/facebookresearch/DiT
|
||||
class MMAudio(nn.Module):
    """
    Multimodal transformer that predicts the flow for audio latents,
    conditioned on clip, sync and text features.

    The model is a stack of ``depth - fused_depth`` JointBlocks (latent, clip
    and text streams) followed by ``fused_depth`` MMDitSingleBlocks (latent
    stream only) and a FinalBlock that maps back to ``latent_dim``.
    """

    def __init__(self,
                 *,
                 latent_dim: int,
                 clip_dim: int,
                 sync_dim: int,
                 text_dim: int,
                 hidden_dim: int,
                 depth: int,
                 fused_depth: int,
                 num_heads: int,
                 mlp_ratio: float = 4.0,
                 latent_seq_len: int,
                 clip_seq_len: int,
                 sync_seq_len: int,
                 text_seq_len: int = 77,
                 latent_mean: Optional[torch.Tensor] = None,
                 latent_std: Optional[torch.Tensor] = None,
                 empty_string_feat: Optional[torch.Tensor] = None,
                 v2: bool = False) -> None:
        """
        All arguments are keyword-only.

        latent_dim / clip_dim / sync_dim / text_dim: feature width of each input.
        hidden_dim: transformer width; num_heads: attention heads.
        depth / fused_depth: total number of blocks / number of single-stream blocks.
        *_seq_len: expected sequence length of each input (checked at run time).
        latent_mean / latent_std: per-channel latent statistics; both or neither.
            When omitted they are filled with NaN and are expected to be loaded
            from a checkpoint later.
        empty_string_feat: text features of the empty string; zeros when omitted.
        v2: selects SiLU input projections and a different timestep embedder.
        """
        super().__init__()

        self.v2 = v2
        self.latent_dim = latent_dim
        self._latent_seq_len = latent_seq_len
        self._clip_seq_len = clip_seq_len
        self._sync_seq_len = sync_seq_len
        self._text_seq_len = text_seq_len
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        if v2:
            self.audio_input_proj = nn.Sequential(
                ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3),
                nn.SiLU(),
                ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
            )

            self.clip_input_proj = nn.Sequential(
                nn.Linear(clip_dim, hidden_dim),
                nn.SiLU(),
                ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
            )

            self.sync_input_proj = nn.Sequential(
                ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3),
                nn.SiLU(),
                ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
            )

            self.text_input_proj = nn.Sequential(
                nn.Linear(text_dim, hidden_dim),
                nn.SiLU(),
                MLP(hidden_dim, hidden_dim * 4),
            )
        else:
            self.audio_input_proj = nn.Sequential(
                ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3),
                nn.SELU(),
                ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
            )

            self.clip_input_proj = nn.Sequential(
                nn.Linear(clip_dim, hidden_dim),
                ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
            )

            self.sync_input_proj = nn.Sequential(
                ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3),
                nn.SELU(),
                ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
            )

            self.text_input_proj = nn.Sequential(
                nn.Linear(text_dim, hidden_dim),
                MLP(hidden_dim, hidden_dim * 4),
            )

        self.clip_cond_proj = nn.Linear(hidden_dim, hidden_dim)
        self.text_cond_proj = nn.Linear(hidden_dim, hidden_dim)
        self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4)
        # each synchformer output segment has 8 feature frames
        self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, sync_dim)))

        self.final_layer = FinalBlock(hidden_dim, latent_dim)

        if v2:
            self.t_embed = TimestepEmbedder(hidden_dim,
                                            frequency_embedding_size=hidden_dim,
                                            max_period=1)
        else:
            self.t_embed = TimestepEmbedder(hidden_dim,
                                            frequency_embedding_size=256,
                                            max_period=10000)
        # Only the last joint block is constructed with pre_only=True.
        self.joint_blocks = nn.ModuleList([
            JointBlock(hidden_dim,
                       num_heads,
                       mlp_ratio=mlp_ratio,
                       pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth)
        ])

        self.fused_blocks = nn.ModuleList([
            MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
            for _ in range(fused_depth)
        ])

        if latent_mean is None:
            # these values are not meant to be used
            # if you don't provide mean/std here, we should load them later from a checkpoint
            assert latent_std is None
            latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
            latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan'))
        else:
            assert latent_std is not None
            assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}'
        if empty_string_feat is None:
            empty_string_feat = torch.zeros((text_seq_len, text_dim))
        self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False)
        self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False)

        self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)
        self.empty_clip_feat = nn.Parameter(torch.zeros(1, clip_dim), requires_grad=True)
        self.empty_sync_feat = nn.Parameter(torch.zeros(1, sync_dim), requires_grad=True)

        self.initialize_weights()
        self.initialize_rotations()

    def initialize_rotations(self):
        """(Re)compute the rotary-embedding tables for the current sequence lengths.

        The clip table uses a frequency scaling of latent_seq_len / clip_seq_len.
        Both tables are non-persistent buffers, so they are never part of a
        state_dict.
        """
        base_freq = 1.0
        head_dim = self.hidden_dim // self.num_heads
        latent_rot = compute_rope_rotations(self._latent_seq_len,
                                            head_dim,
                                            10000,
                                            freq_scaling=base_freq,
                                            device=self.device)
        clip_rot = compute_rope_rotations(self._clip_seq_len,
                                          head_dim,
                                          10000,
                                          freq_scaling=base_freq * self._latent_seq_len /
                                          self._clip_seq_len,
                                          device=self.device)

        # register_buffer is the long-standing equivalent of assigning an
        # nn.Buffer and may be called again for an already-registered name.
        self.register_buffer('latent_rot', latent_rot, persistent=False)
        self.register_buffer('clip_rot', clip_rot, persistent=False)

    def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
        """Change the expected sequence lengths and rebuild the rotation tables."""
        self._latent_seq_len = latent_seq_len
        self._clip_seq_len = clip_seq_len
        self._sync_seq_len = sync_seq_len
        self.initialize_rotations()

    def initialize_weights(self):
        """Xavier-initialise every nn.Linear, then apply the special cases below."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.joint_blocks:
            for stream in (block.latent_block, block.clip_block, block.text_block):
                nn.init.constant_(stream.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(stream.adaLN_modulation[-1].bias, 0)
        for block in self.fused_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.conv.weight, 0)
        nn.init.constant_(self.final_layer.conv.bias, 0)

        # empty string feat shall be initialized by a CLIP encoder
        nn.init.constant_(self.sync_pos_emb, 0)
        nn.init.constant_(self.empty_clip_feat, 0)
        nn.init.constant_(self.empty_sync_feat, 0)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``(x - latent_mean) / latent_std`` IN PLACE and return ``x``."""
        return x.sub_(self.latent_mean).div_(self.latent_std)

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x * latent_std + latent_mean`` IN PLACE and return ``x``."""
        return x.mul_(self.latent_std).add_(self.latent_mean)

    def preprocess_conditions(self, clip_f: torch.Tensor, sync_f: torch.Tensor,
                              text_f: torch.Tensor) -> PreprocessedConditions:
        """
        cache computations that do not depend on the latent/time step
        i.e., the features are reused over steps during inference

        The second dimension of each input must equal the configured clip,
        sync and text sequence length respectively (asserted).
        """
        assert clip_f.shape[1] == self._clip_seq_len, f'{clip_f.shape=} {self._clip_seq_len=}'
        assert sync_f.shape[1] == self._sync_seq_len, f'{sync_f.shape=} {self._sync_seq_len=}'
        assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}'

        bs = clip_f.shape[0]

        # B * num_segments (24) * 8 * 768
        # add the per-frame positional embedding within each 8-frame segment
        num_sync_segments = self._sync_seq_len // 8
        sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb
        sync_f = sync_f.flatten(1, 2)  # (B, VN, D)

        # extend vf to match x
        clip_f = self.clip_input_proj(clip_f)  # (B, VN, D)
        sync_f = self.sync_input_proj(sync_f)  # (B, VN, D)
        text_f = self.text_input_proj(text_f)  # (B, VN, D)

        # upsample the sync features to match the audio
        sync_f = sync_f.transpose(1, 2)  # (B, D, VN)
        sync_f = F.interpolate(sync_f, size=self._latent_seq_len, mode='nearest-exact')
        sync_f = sync_f.transpose(1, 2)  # (B, N, D)

        # get conditional features from the clip side
        clip_f_c = self.clip_cond_proj(clip_f.mean(dim=1))  # (B, D)
        text_f_c = self.text_cond_proj(text_f.mean(dim=1))  # (B, D)

        return PreprocessedConditions(clip_f=clip_f,
                                      sync_f=sync_f,
                                      text_f=text_f,
                                      clip_f_c=clip_f_c,
                                      text_f_c=text_f_c)

    def predict_flow(self, latent: torch.Tensor, t: torch.Tensor,
                     conditions: PreprocessedConditions) -> torch.Tensor:
        """
        for non-cacheable computations

        latent: (B, N, C) with N equal to the configured latent sequence length
        t: (B,)
        conditions: output of preprocess_conditions / get_empty_conditions
        """
        assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}'

        clip_f = conditions.clip_f
        sync_f = conditions.sync_f
        text_f = conditions.text_f
        clip_f_c = conditions.clip_f_c
        text_f_c = conditions.text_f_c

        latent = self.audio_input_proj(latent)  # (B, N, D)
        global_c = self.global_cond_mlp(clip_f_c + text_f_c)  # (B, D)

        global_c = self.t_embed(t).unsqueeze(1) + global_c.unsqueeze(1)  # (B, 1, D)
        extended_c = global_c + sync_f  # broadcasts over the latent sequence

        for block in self.joint_blocks:
            latent, clip_f, text_f = block(latent, clip_f, text_f, global_c, extended_c,
                                           self.latent_rot, self.clip_rot)  # (B, N, D)

        for block in self.fused_blocks:
            latent = block(latent, extended_c, self.latent_rot)

        flow = self.final_layer(latent, global_c)  # (B, N, out_dim), remove t
        return flow

    def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, sync_f: torch.Tensor,
                text_f: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        latent: (B, N, C)
        vf: (B, T, C_V)
        t: (B,)
        """
        conditions = self.preprocess_conditions(clip_f, sync_f, text_f)
        flow = self.predict_flow(latent, t, conditions)
        return flow

    def get_empty_string_sequence(self, bs: int) -> torch.Tensor:
        """Return the empty-string text features expanded (not copied) to batch size ``bs``."""
        return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)

    def get_empty_clip_sequence(self, bs: int) -> torch.Tensor:
        """Return the learned empty clip feature expanded to (bs, clip_seq_len, clip_dim)."""
        return self.empty_clip_feat.unsqueeze(0).expand(bs, self._clip_seq_len, -1)

    def get_empty_sync_sequence(self, bs: int) -> torch.Tensor:
        """Return the learned empty sync feature expanded to (bs, sync_seq_len, sync_dim)."""
        return self.empty_sync_feat.unsqueeze(0).expand(bs, self._sync_seq_len, -1)

    def get_empty_conditions(
            self,
            bs: int,
            *,
            negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions:
        """
        Build the "empty" conditions used as the second branch in ode_wrapper.

        The empty clip/sync features are preprocessed once with batch size 1
        and then expanded to ``bs``. When ``negative_text_features`` is given
        it replaces the empty-string text and its processed text features are
        left at whatever batch size it came with.
        """
        if negative_text_features is not None:
            empty_text = negative_text_features
        else:
            empty_text = self.get_empty_string_sequence(1)

        empty_clip = self.get_empty_clip_sequence(1)
        empty_sync = self.get_empty_sync_sequence(1)
        conditions = self.preprocess_conditions(empty_clip, empty_sync, empty_text)
        conditions.clip_f = conditions.clip_f.expand(bs, -1, -1)
        conditions.sync_f = conditions.sync_f.expand(bs, -1, -1)
        conditions.clip_f_c = conditions.clip_f_c.expand(bs, -1)
        if negative_text_features is None:
            conditions.text_f = conditions.text_f.expand(bs, -1, -1)
            conditions.text_f_c = conditions.text_f_c.expand(bs, -1)

        return conditions

    def ode_wrapper(self, t: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions,
                    empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor:
        """
        Flow function for an ODE solver.

        ``t`` is broadcast to one value per batch element. With
        ``cfg_strength < 1`` only the conditional flow is evaluated; otherwise
        the result is
        ``cfg_strength * flow(conditions) + (1 - cfg_strength) * flow(empty_conditions)``.
        """
        t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype)

        if cfg_strength < 1.0:
            return self.predict_flow(latent, t, conditions)
        else:
            return (cfg_strength * self.predict_flow(latent, t, conditions) +
                    (1 - cfg_strength) * self.predict_flow(latent, t, empty_conditions))

    def load_weights(self, src_dict) -> None:
        """
        Strictly load ``src_dict``, ignoring the entries for the non-persistent
        buffers ('t_embed.freqs', 'latent_rot', 'clip_rot') if present.

        The caller's mapping is left untouched: the keys are dropped from a
        shallow copy (tensors are shared, and any attributes attached to the
        mapping, such as ``_metadata``, are carried over).
        """
        import copy

        state = copy.copy(src_dict)
        for key in ('t_embed.freqs', 'latent_rot', 'clip_rot'):
            state.pop(key, None)

        self.load_state_dict(state, strict=True)

    @property
    def device(self) -> torch.device:
        return self.latent_mean.device

    @property
    def latent_seq_len(self) -> int:
        return self._latent_seq_len

    @property
    def clip_seq_len(self) -> int:
        return self._clip_seq_len

    @property
    def sync_seq_len(self) -> int:
        return self._sync_seq_len
|
||||
|
||||
|
||||
def small_16k(**kwargs) -> MMAudio:
    """Build the 'small_16k' MMAudio preset (7 heads, latent_dim=20, latent_seq_len=250).

    Extra keyword arguments are forwarded to the MMAudio constructor.
    """
    heads = 7
    preset = dict(
        latent_dim=20,
        clip_dim=1024,
        sync_dim=768,
        text_dim=1024,
        hidden_dim=64 * heads,
        depth=12,
        fused_depth=8,
        num_heads=heads,
        latent_seq_len=250,
        clip_seq_len=64,
        sync_seq_len=192,
    )
    return MMAudio(**preset, **kwargs)
|
||||
|
||||
|
||||
def small_44k(**kwargs) -> MMAudio:
    """Build the 'small_44k' MMAudio preset (7 heads, latent_dim=40, latent_seq_len=345).

    Extra keyword arguments are forwarded to the MMAudio constructor.
    """
    heads = 7
    preset = dict(
        latent_dim=40,
        clip_dim=1024,
        sync_dim=768,
        text_dim=1024,
        hidden_dim=64 * heads,
        depth=12,
        fused_depth=8,
        num_heads=heads,
        latent_seq_len=345,
        clip_seq_len=64,
        sync_seq_len=192,
    )
    return MMAudio(**preset, **kwargs)
|
||||
|
||||
|
||||
def medium_44k(**kwargs) -> MMAudio:
    """Build the 'medium_44k' MMAudio preset (14 heads, depth 12, latent_dim=40).

    Extra keyword arguments are forwarded to the MMAudio constructor.
    """
    heads = 14
    preset = dict(
        latent_dim=40,
        clip_dim=1024,
        sync_dim=768,
        text_dim=1024,
        hidden_dim=64 * heads,
        depth=12,
        fused_depth=8,
        num_heads=heads,
        latent_seq_len=345,
        clip_seq_len=64,
        sync_seq_len=192,
    )
    return MMAudio(**preset, **kwargs)
|
||||
|
||||
|
||||
def large_44k(**kwargs) -> MMAudio:
    """Build the 'large_44k' MMAudio preset (14 heads, depth 21, fused_depth 14).

    Extra keyword arguments are forwarded to the MMAudio constructor.
    """
    heads = 14
    preset = dict(
        latent_dim=40,
        clip_dim=1024,
        sync_dim=768,
        text_dim=1024,
        hidden_dim=64 * heads,
        depth=21,
        fused_depth=14,
        num_heads=heads,
        latent_seq_len=345,
        clip_seq_len=64,
        sync_seq_len=192,
    )
    return MMAudio(**preset, **kwargs)
|
||||
|
||||
|
||||
def large_44k_v2(**kwargs) -> MMAudio:
    """Build the 'large_44k_v2' MMAudio preset: the large_44k sizes with v2=True.

    Extra keyword arguments are forwarded to the MMAudio constructor.
    """
    heads = 14
    preset = dict(
        latent_dim=40,
        clip_dim=1024,
        sync_dim=768,
        text_dim=1024,
        hidden_dim=64 * heads,
        depth=21,
        fused_depth=14,
        num_heads=heads,
        latent_seq_len=345,
        clip_seq_len=64,
        sync_seq_len=192,
        v2=True,
    )
    return MMAudio(**preset, **kwargs)
|
||||
|
||||
|
||||
def get_my_mmaudio(name: str, **kwargs) -> MMAudio:
    """Build the MMAudio preset called ``name``.

    Keyword arguments are forwarded to the preset builder.
    Raises ValueError for an unknown name.
    """
    presets = (
        ('small_16k', small_16k),
        ('small_44k', small_44k),
        ('medium_44k', medium_44k),
        ('large_44k', large_44k),
        ('large_44k_v2', large_44k_v2),
    )
    for preset_name, builder in presets:
        if name == preset_name:
            return builder(**kwargs)

    raise ValueError(f'Unknown model name: {name}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the smallest preset and report its size.
    network = get_my_mmaudio('small_16k')

    # print the number of parameters in terms of millions
    param_count = sum(p.numel() for p in network.parameters())
    num_params = param_count / 1e6
    print(f'Number of parameters: {num_params:.2f}M')
|
||||
@@ -0,0 +1,188 @@
|
||||
from typing import Optional, Union, List, Tuple, Any, Mapping
|
||||
from dataclasses import dataclass
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from selva_core.model.text_synchformer import TextSynchformer
|
||||
from selva_core.utils.transforms import generate_multiple_segments
|
||||
|
||||
|
||||
@dataclass
class PreprocessedConditions:
    """Bundle of precomputed sync/text conditioning tensors.

    NOTE(review): nothing in the visible part of this module constructs or
    reads this class, so the per-field notes below are inferred from the field
    names only and should be confirmed against the caller.
    """
    # sync features (presumably per-frame) -- TODO confirm shape
    sync_f: torch.Tensor
    # global/pooled sync condition (presumably) -- TODO confirm
    sync_f_c: torch.Tensor
    # text features (presumably per-token) -- TODO confirm shape
    text_f: torch.Tensor
    # global/pooled text condition (presumably) -- TODO confirm
    text_f_c: torch.Tensor
    # mask over the text tokens -- TODO confirm dtype and polarity
    text_mask: torch.Tensor
|
||||
|
||||
|
||||
class TextSynch(TextSynchformer):
    """
    TextSynchformer with learnable supplementary text tokens, an empty-string
    text feature, chunked video encoding and a state_dict loader that keeps
    only the entries this model owns.
    """

    def __init__(self,
                 *,
                 text_dim: int,
                 video_seq_len: int = 192,
                 max_text_seq_len: int = 512,
                 empty_string_feat: Optional[torch.Tensor] = None,
                 num_sup_text_tokens: int = 5,
                 sync_batch_size_multiplier: Union[int, float] = -1,
                 xattn_depth: int = 1,
                 ) -> None:
        """
        text_dim / max_text_seq_len / xattn_depth: forwarded to TextSynchformer.
        video_seq_len: stored and exposed through the ``video_seq_len`` property.
        empty_string_feat: text features of the empty string; zeros of shape
            (1, text_dim) when omitted.
        num_sup_text_tokens: number of learnable tokens prepended to the text
            (disabled when <= 0).
        sync_batch_size_multiplier: chunk size for encode_video_with_sync as a
            multiple of the batch size; <= 0 means one batch size per chunk.
        """
        super().__init__(
            text_dim=text_dim,
            max_text_seq_len=max_text_seq_len,
            xattn_depth=xattn_depth,
        )

        self._video_seq_len = video_seq_len
        self.num_sup_text_tokens = num_sup_text_tokens
        self.sync_batch_size_multiplier = sync_batch_size_multiplier

        if num_sup_text_tokens > 0:
            self.sup_text_feat = nn.Parameter(torch.zeros(num_sup_text_tokens, self.text_dim),
                                              requires_grad=True)
        if empty_string_feat is None:
            empty_string_feat = torch.zeros((1, text_dim))
        self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)

        self.initialize_weights()

    def update_seq_lengths(self, video_seq_len: int) -> None:
        """Change the stored video sequence length."""
        self._video_seq_len = video_seq_len

    def get_empty_string_sequence(self, bs: int) -> torch.Tensor:
        """Return the empty-string text features expanded (not copied) to batch size ``bs``."""
        return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)

    def get_sup_text_sequence(self, bs: int) -> torch.Tensor:
        """Return the supplementary text tokens expanded to (bs, S, D).

        Raises ValueError when supplementary tokens are disabled.
        """
        if self.num_sup_text_tokens <= 0:
            raise ValueError(f'supplementary text tokens not enabled as {self.num_sup_text_tokens=}')
        return self.sup_text_feat.expand(bs, -1, -1)

    def prepend_sup_text_tokens(self, text_f: torch.Tensor, text_mask: torch.Tensor) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepend the supplementary tokens to ``text_f`` and matching ones to ``text_mask``.

        Both inputs are returned unchanged when supplementary tokens are disabled.
        """
        if self.num_sup_text_tokens <= 0:
            return text_f, text_mask
        bs = text_f.shape[0]
        sup_text_f = self.get_sup_text_sequence(bs)  # (B, S, D)
        sup_text_mask = torch.ones(bs, sup_text_f.shape[1],
                                   device=text_mask.device, dtype=text_mask.dtype)  # (B, S)
        text_f = torch.cat([sup_text_f, text_f], dim=1)
        text_mask = torch.cat([sup_text_mask, text_mask], dim=1)
        return text_f, text_mask

    def encode_video_with_sync(self, x: torch.Tensor, text_f: torch.Tensor,
                               text_mask: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of videos segment by segment, in chunks.

        x: (B, T, C, H, W) with C == 3 and H == W == 224 (asserted)
        text_f / text_mask: per-video text inputs; each video's row is repeated
            once per segment of that video that falls into the current chunk.
        Returns a (B, S * t, D) Tensor, the segments of each video concatenated
        along the sequence axis.
        """
        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224

        # partition the video
        segment_size = 16
        step_size = 8
        x = generate_multiple_segments(x, segment_size, step_size)  # (B, S, T, C, H, W)
        num_segments = x.shape[1]
        total = b * num_segments

        outputs = []
        if self.sync_batch_size_multiplier <= 0:
            batch_size = b
        else:
            # A fractional multiplier can round down to 0, which would make
            # the range() step below invalid; always process at least one.
            batch_size = max(1, int(b * self.sync_batch_size_multiplier))
        x = einops.rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
        for i in range(0, total, batch_size):
            chunk_end = min(i + batch_size, total)
            # videos whose segments intersect the flat range [i, chunk_end)
            start_idx = i // num_segments
            end_idx = min((i + batch_size - 1) // num_segments + 1, b)
            text_f_batch = text_f[start_idx:end_idx]
            text_mask_batch = text_mask[start_idx:end_idx]

            # Video k owns flat indices [k * S, (k + 1) * S); its repeat count
            # is the size of that range's overlap with the current chunk.
            counts = [
                min((k + 1) * num_segments, chunk_end) - max(k * num_segments, i)
                for k in range(start_idx, end_idx)
            ]
            repeats = torch.tensor(counts, dtype=torch.long, device=x.device)

            text_f_batch_repeated = torch.repeat_interleave(text_f_batch, repeats, dim=0)
            text_mask_batch_repeated = torch.repeat_interleave(text_mask_batch, repeats, dim=0)

            outputs.append(self.forward_vfeat(
                x[i:i + batch_size],
                text_f=text_f_batch_repeated,
                text_mask=text_mask_batch_repeated
            ))
        x = torch.cat(outputs, dim=0)
        x = einops.rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
        return x

    def encode_audio_with_sync(self, x: torch.Tensor, text_f: torch.Tensor,
                               text_mask: torch.Tensor) -> torch.Tensor:
        """Encode audio with the text inputs; delegates to ``forward_afeat``."""
        return self.forward_afeat(
            x, text_f=text_f, text_mask=text_mask
        )

    def load_synchformer_state_dict(self, src_dict: dict):
        """Strictly load ``src_dict`` through this class's filtering load_state_dict."""
        self.load_state_dict(src_dict, strict=True)

    def load_weights(self, src_dict) -> None:
        """Strictly load ``src_dict`` through this class's filtering load_state_dict."""
        self.load_state_dict(src_dict, strict=True)

    @property
    def device(self) -> torch.device:
        return self.empty_string_feat.device

    @property
    def dtype(self) -> torch.dtype:
        return self.empty_string_feat.dtype

    @property
    def video_seq_len(self) -> int:
        return self._video_seq_len

    @property
    def audio_seq_len(self) -> int:
        # NOTE(review): ``_audio_seq_len`` is never assigned in this class;
        # confirm that TextSynchformer sets it, otherwise this raises
        # AttributeError.
        return self._audio_seq_len

    def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
        """
        Load only the entries whose key starts with one of this model's own
        prefixes and discard everything else.

        The kept prefixes are the video/audio feature extractors (when
        ``self.video`` / ``self.audio`` are set), the text projection, the
        cross-attention blocks, and the supplementary/empty text features.
        nn.Module.load_state_dict is called directly on the filtered mapping.
        """
        target_keys = (['vfeat_extractor'] if self.video else []) \
            + (['afeat_extractor'] if self.audio else []) \
            + ['text_proj', 'synch_text_cross_blocks',
               'sup_text_feat', 'empty_string_feat']
        prefixes = tuple(target_keys)
        sd = {k: v for k, v in sd.items() if k.startswith(prefixes)}

        return nn.Module.load_state_dict(self, sd, strict=strict)
|
||||
|
||||
|
||||
def depth1(**kwargs) -> TextSynch:
    """Build the 'depth1' TextSynch preset (text_dim=768, a single cross-attention block).

    Extra keyword arguments are forwarded to the TextSynch constructor.
    """
    preset = dict(
        text_dim=768,
        video_seq_len=192,
        max_text_seq_len=512,
        xattn_depth=1,
    )
    return TextSynch(**preset, **kwargs)
|
||||
|
||||
|
||||
def get_my_textsynch(name: str, **kwargs) -> TextSynch:
    """Return the model preset selected by ``name``.

    Only names starting with ``'depth1'`` are recognised.

    Raises:
        ValueError: If ``name`` does not match a known preset.
    """
    if not name.startswith('depth1'):
        raise ValueError(f'Unknown model name: {name}')
    return depth1(**kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    network = get_my_textsynch('depth1')

    # print the number of parameters in terms of millions
    total_elements = 0
    for param in network.parameters():
        total_elements += param.numel()
    num_params = total_elements / 1e6
    print(f'Number of parameters: {num_params:.2f}M')

    # Smoke check only: the objects returned by torch.compile are discarded.
    compile_targets = (
        ('encode_video_with_sync', 'Compiled encode_video_with_sync'),
        ('predict_flow', 'Compiled predict_flow'),
        ('preprocess_conditions', 'Compiled preprocess_conditions:'),
        ('forward', 'Compiled forward:'),
    )
    for attr_name, message in compile_targets:
        torch.compile(getattr(network, attr_name))
        print(message)
|
||||
@@ -0,0 +1,62 @@
|
||||
import dataclasses
|
||||
import math
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SequenceConfig:
    """Sequence-length bookkeeping derived from a clip duration.

    All derived lengths are exposed as read-only properties computed from the
    fields below.
    """

    # general
    duration: float

    # audio
    sampling_rate: int
    spectrogram_frame_rate: int
    latent_downsample_rate: int = 2

    # visual
    clip_frame_rate: int = 8
    sync_frame_rate: int = 25
    sync_num_frames_per_segment: int = 16
    sync_step_size: int = 8
    sync_downsample_rate: int = 2

    @property
    def num_audio_frames(self) -> int:
        """Audio length that corresponds to a whole number of latents."""
        # we need an integer number of latents
        per_spectrogram = self.latent_seq_len * self.spectrogram_frame_rate
        return per_spectrogram * self.latent_downsample_rate

    @property
    def latent_seq_len(self) -> int:
        """Number of latents, rounded up so the full duration is covered."""
        total_samples = self.duration * self.sampling_rate
        latents = total_samples / self.spectrogram_frame_rate / self.latent_downsample_rate
        return int(math.ceil(latents))

    @property
    def clip_seq_len(self) -> int:
        """Number of frames at ``clip_frame_rate`` (truncated to int)."""
        frames = self.duration * self.clip_frame_rate
        return int(frames)

    @property
    def sync_seg_len(self) -> int:
        """Number of sliding-window segments at ``sync_frame_rate``."""
        num_frames = self.duration * self.sync_frame_rate
        frames_past_first = num_frames - self.sync_num_frames_per_segment
        num_segments = frames_past_first // self.sync_step_size + 1
        return int(num_segments)

    @property
    def sync_seq_len(self) -> int:
        """Total sync tokens: segments x frames per segment / downsample rate."""
        frames = self.sync_seg_len * self.sync_num_frames_per_segment
        return int(frames / self.sync_downsample_rate)
|
||||
|
||||
|
||||
CONFIG_16K = SequenceConfig(duration=8.0, sampling_rate=16000, spectrogram_frame_rate=256)
CONFIG_44K = SequenceConfig(duration=8.0, sampling_rate=44100, spectrogram_frame_rate=512)

if __name__ == '__main__':
    # Self-check: (latent_seq_len, clip_seq_len, sync_seq_len, num_audio_frames)
    # expected for each preset.
    expectations = (
        (CONFIG_16K, (250, 64, 192, 128000)),
        (CONFIG_44K, (345, 64, 192, 353280)),
    )
    for config, (latent_len, clip_len, sync_len, audio_frames) in expectations:
        assert config.latent_seq_len == latent_len
        assert config.clip_seq_len == clip_len
        assert config.sync_seq_len == sync_len
        assert config.num_audio_frames == audio_frames

    print('Passed')
|
||||
@@ -0,0 +1,199 @@
|
||||
import logging
|
||||
from typing import Any, Mapping
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from selva_core.ext.synchformer.motionformer import MotionFormer, SpatialTransformerEncoderLayer, BaseEncoderLayer
|
||||
from selva_core.ext.synchformer.astransformer import AST
|
||||
from selva_core.model.transformer_layers import (MMCrossAttentionBlock)
|
||||
from selva_core.model.low_level import MLP
|
||||
|
||||
|
||||
class ExtendedMotionFormer(MotionFormer):
    """MotionFormer that exposes feature extraction and aggregation as two steps.

    Splitting the steps lets a caller modify the per-token features (e.g. with
    text cross-attention) before they are aggregated.
    """

    def forward_segments_without_aggregation(self, x, orig_shape: tuple) -> tuple[torch.Tensor, torch.Tensor]:
        """Extract token features and stop before spatio-temporal aggregation.

        Args:
            x: Input tensor of shape (BS, C, T, H, W) where S is the number of segments
            orig_shape: Original shape tuple (B, S, C, T, H, W)

        Returns:
            Tuple of (features, mask). When ``factorize_space_time`` is set the
            features are restored to (B*S, D, t, h, w); otherwise they stay as
            a token sequence.
        """
        tokens, token_mask = self.forward_features(x)

        assert self.extract_features

        # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
        tokens = tokens[:, 1:, :]  # drop the CLS token for efficiency
        tokens = self.pre_logits(self.norm(tokens))

        if not self.factorize_space_time:
            return tokens, token_mask

        # (B*S, D, t, h, w) <- (B*S, t*h*w, D)
        restored = self.restore_spatio_temp_dims(tokens, orig_shape)
        return restored, token_mask

    def spatiotemporal_aggregation(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
        """Aggregate features over space, then over time.

        Args:
            x: Features tensor of shape (B*S, D, t, h, w)
            x_mask: Mask tensor passed to the spatial aggregation module

        Returns:
            Aggregated features of shape (B*S, D), or (B*S, t, D) when
            ``self.temp_attn_agg`` is ``Identity``. The input is returned
            unchanged when ``factorize_space_time`` is off.
        """
        if not self.factorize_space_time:
            return x

        per_frame = self.spatial_attn_agg(x, x_mask)  # (B*S, t, D)
        return self.temp_attn_agg(per_frame)
|
||||
|
||||
|
||||
|
||||
class TextSynchformer(nn.Module):
    """Synchformer feature extractor whose visual tokens cross-attend to text.

    Visual tokens from ``ExtendedMotionFormer`` are refined by a stack of
    ``MMCrossAttentionBlock`` layers conditioned on projected text features,
    and only then aggregated over space and time. Audio feature extraction is
    not implemented.
    """

    def __init__(self, video: bool = True, audio: bool = False,
                 text_dim: int = 1024, max_text_seq_len: int = 512, xattn_depth: int = 1):
        """
        Args:
            video: Build the video feature extractor.
            audio: Build the audio feature extractor.
            text_dim: Channel size of the incoming text features.
            max_text_seq_len: Stored on the instance; not used inside this class.
            xattn_depth: Number of text cross-attention blocks.

        Raises:
            ValueError: If both ``video`` and ``audio`` are False.
        """
        super().__init__()

        self.video = video
        self.audio = audio
        self.text_dim = text_dim
        self.max_text_seq_len = max_text_seq_len

        if not video and not audio:
            raise ValueError('At least one of video or audio should be True.')

        # Use ExtendedMotionFormer directly instead of inheriting from Synchformer
        if self.video:
            self.vfeat_extractor = ExtendedMotionFormer(
                extract_features=True,
                factorize_space_time=True,
                agg_space_module='TransformerEncoderLayer',
                agg_time_module='torch.nn.Identity',
                add_global_repr=False
            )

        if self.audio:
            self.afeat_extractor = AST(
                extract_features=True,
                max_spec_t=66,
                factorize_freq_time=True,
                agg_freq_module='TransformerEncoderLayer',
                agg_time_module='torch.nn.Identity',
                add_global_repr=False
            )

        # Get embedding dimensions from the video feature extractor
        if self.video:
            self.embed_dim = self.vfeat_extractor.embed_dim
            self.num_heads = self.vfeat_extractor.num_heads
            self.mlp_ratio = self.vfeat_extractor.mlp_ratio
        else:
            # Default values if no video
            self.embed_dim = 768
            self.num_heads = 12
            self.mlp_ratio = 4

        self.text_proj = nn.Sequential(
            nn.Linear(self.text_dim, self.embed_dim),
            nn.SiLU(),
            MLP(self.embed_dim, self.embed_dim * 4)
        )
        self.synch_text_cross_blocks = nn.ModuleList([
            MMCrossAttentionBlock(self.embed_dim, self.num_heads,
                                  mlp_ratio=self.mlp_ratio,
                                  kernel_size=1, padding=0,
                                  residual=True)
            for _ in range(xattn_depth)
        ])

    def initialize_weights(self):
        """Xavier-initialise the text path and zero the cross-block outputs.

        Zeroing ``norm1``'s affine parameters makes the attention branch of each
        residual block contribute nothing at initialisation; ``ffn.w2`` is
        zeroed as well (presumably the FFN output projection - confirm against
        ``MLP``).
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.text_proj.apply(_basic_init)
        self.synch_text_cross_blocks.apply(_basic_init)
        for block in self.synch_text_cross_blocks:
            nn.init.constant_(block.norm1.weight, 0.0)
            nn.init.constant_(block.norm1.bias, 0.0)
            nn.init.constant_(block.ffn.w2.weight, 0.0)

    def forward(self, data, text_features, text_mask=None):
        """Extract text-conditioned features for the enabled modalities.

        Args:
            data: ``(video, audio)`` when both modalities are enabled, otherwise
                the single enabled input.
            text_features: Text features forwarded to the per-modality methods.
            text_mask: Optional mask over the text tokens. Previously this
                could not be supplied, and the call into ``forward_vfeat``
                failed with a missing-argument ``TypeError``.

        Returns:
            ``(video, audio)`` when both are enabled, else the single result.
        """
        video, audio = None, None

        if self.video and self.audio:
            video, audio = data
        elif self.video:
            video = data
        elif self.audio:
            audio = data

        if self.video and video is not None:
            video = self.forward_vfeat(video, text_features, text_mask)
        if self.audio and audio is not None:
            audio = self.forward_afeat(audio, text_features, text_mask)

        if self.video and self.audio:
            return video, audio
        elif self.video:
            return video
        else:
            return audio

    def forward_vfeat(self, vis, text_f, text_mask=None):
        """Extract visual features with text cross-attention before aggregation.

        Args:
            vis: Video segments of shape (B, S, Tv, C, H, W).
            text_f: Text features with ``text_dim`` channels in the last axis.
            text_mask: Optional mask over text tokens, passed to each block as
                ``context_mask``.

        Returns:
            Aggregated features reshaped to (B, S, ...).
        """
        B, S, Tv, C, H, W = vis.shape
        vis = vis.permute(0, 1, 3, 2, 4, 5)  # (B, S, C, Tv, H, W)

        # Flatten for processing
        orig_shape = (B, S, C, Tv, H, W)
        vis = einops.rearrange(vis, 'B S C Tv H W -> (B S) C Tv H W')
        vis, vis_mask = self.vfeat_extractor.forward_segments_without_aggregation(
            vis, orig_shape  # B*S D t h w , BS t h w
        )

        text_f = self.text_proj(text_f)  # last axis: text_dim -> embed_dim

        _, D, t, h, w = vis.shape
        # Merge all segments of one sample into a single token sequence so the
        # text cross-attention sees the whole clip at once.
        vis = einops.rearrange(vis, '(B S) D t h w -> B (S t h w) D', B=B, S=S)
        vis_mask = einops.rearrange(vis_mask, '(B S) t h w -> B (S t h w)', B=B, S=S) \
            if vis_mask is not None else None
        for block in self.synch_text_cross_blocks:
            vis = block(vis, text_f, rot=None, x_mask=vis_mask, context_mask=text_mask)

        # Back to the per-segment layout expected by the aggregation modules.
        vis = einops.rearrange(vis, 'B (S t h w) D -> (B S) D t h w', B=B, S=S, D=D, t=t, h=h, w=w)
        vis_mask = einops.rearrange(vis_mask, 'B (S t h w) -> (B S) t h w', B=B, S=S, t=t, h=h, w=w) \
            if vis_mask is not None else None
        vis = self.vfeat_extractor.spatiotemporal_aggregation(
            vis, vis_mask
        )
        vis = vis.view(B, S, *vis.shape[1:])

        return vis

    def forward_afeat(self, aud, text_f=None, text_mask=None):
        """Forward audio features (not implemented).

        The text arguments are accepted so that ``forward`` reaches this
        explicit error instead of failing on the call signature.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Audio feature extraction is not implemented in TextSynchformer.")

    def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
        """Load only entries for the extractors, ``text_proj`` and cross blocks.

        Keys with any other prefix are discarded before delegating to
        ``nn.Module.load_state_dict``.
        """
        target_keys = (['vfeat_extractor'] if self.video else []) \
            + (['afeat_extractor'] if self.audio else []) \
            + ['text_proj', 'synch_text_cross_blocks']
        # discard all entries that do not belong to one of the target prefixes
        sd = {k: v for k, v in sd.items() if any(k.startswith(tk)
                                                 for tk in target_keys)}

        return super().load_state_dict(sd, strict)
|
||||
|
||||
@@ -0,0 +1,426 @@
|
||||
from typing import Optional
|
||||
from inspect import isfunction
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from selva_core.ext.rotary_embeddings import apply_rope
|
||||
from selva_core.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Apply the affine modulation ``x * (1 + scale) + shift``."""
    gain = 1 + scale
    return x * gain + shift
|
||||
|
||||
|
||||
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
              attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Scaled dot-product attention with heads merged into the channel axis.

    Args:
        q, k, v: Tensors laid out as (batch, heads, tokens, head_dim).
        attn_mask: Optional mask forwarded to
            ``F.scaled_dot_product_attention``.

    Returns:
        Contiguous tensor of shape (batch, tokens, heads * head_dim).
    """
    # training will crash without these contiguous calls and the CUDNN limitation
    # I believe this is related to https://github.com/pytorch/pytorch/issues/133974
    # unresolved at the time of writing
    q, k, v = (tensor.contiguous() for tensor in (q, k, v))
    mask = attn_mask.contiguous() if attn_mask is not None else None

    attended = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    # equivalent to rearrange(out, 'b h n d -> b n (h d)')
    batch, heads, tokens, head_dim = attended.shape
    merged = attended.transpose(1, 2)  # (b, n, h, d_head)
    # reshape copies when the transposed view is not contiguous
    merged = merged.reshape(batch, tokens, heads * head_dim)
    return merged.contiguous()
|
||||
|
||||
def attention_debug(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    attn_mask: Optional[torch.Tensor] = None,
                    layer_idx: int = -1) -> None:
    """Dump the attention map for one layer to disk instead of attending.

    Computes the attention logits explicitly and saves two tensors in the
    current working directory: the un-normalised logits and the softmaxed
    weights. Nothing is returned and ``v`` does not affect the saved tensors.
    """
    # training will crash without these contiguous calls and the CUDNN limitation
    # I believe this is related to https://github.com/pytorch/pytorch/issues/133974
    # unresolved at the time of writing
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    if attn_mask is not None:
        attn_mask = attn_mask.contiguous()

    # debug attn map
    import math
    scale_factor = 1 / math.sqrt(q.size(-1))
    num_queries, num_keys = q.size(-2), k.size(-2)
    bias = torch.zeros(q.shape[0], q.shape[1], num_queries, num_keys,
                       dtype=q.dtype, device=q.device)
    if attn_mask is None:
        pass
    elif attn_mask.dtype == torch.bool:
        # boolean mask: positions marked False are excluded from attention
        bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
    else:
        # non-boolean mask is treated as an additive bias
        bias = attn_mask + bias

    scores = q @ k.transpose(-2, -1) * scale_factor
    scores += bias

    file_stem = f'./debug_attn_weight_layer{layer_idx}'
    torch.save(scores.clone().cpu(), file_stem + '_unnorm.pt')
    # normalize
    scores = torch.softmax(scores, dim=-1)
    torch.save(scores.clone().cpu(), file_stem + '.pt')
|
||||
|
||||
|
||||
def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
    """Combine per-token query and key masks into one attention mask.

    A missing mask defaults to all-True. The result is the outer product of
    the two masks with a singleton head axis: shape (b, 1, i, j).

    Args:
        q_shape: Shape of the query tensor; batch is axis 0, tokens axis -2.
        k_shape: Shape of the key tensor; tokens are axis -2.
        device: Device for any default mask that has to be created.
        q_mask: Optional (b, i) mask over query tokens.
        k_mask: Optional (b, j) mask over key tokens.
    """
    batch, num_queries, num_keys = q_shape[0], q_shape[-2], k_shape[-2]

    if q_mask is None:
        q_mask = torch.ones((batch, num_queries), device=device, dtype=torch.bool)
    if k_mask is None:
        k_mask = torch.ones((batch, num_keys), device=device, dtype=torch.bool)

    query_side = rearrange(q_mask, 'b i -> b 1 i 1')
    key_side = rearrange(k_mask, 'b j -> b 1 1 j')
    return query_side * key_side
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalised queries and keys.

    There is no output projection; callers apply their own.
    """

    def __init__(self, dim: int, nheads: int):
        """
        Args:
            dim: Channel size of the input tokens.
            nheads: Number of attention heads; each head has ``dim // nheads``
                channels.
        """
        super().__init__()
        self.dim = dim
        self.nheads = nheads

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = nn.RMSNorm(dim // nheads)
        self.k_norm = nn.RMSNorm(dim // nheads)

        self.split_into_heads = Rearrange('b n (h d j) -> b h n d j',
                                          h=nheads,
                                          d=dim // nheads,
                                          j=3)

    def pre_attention(
            self, x: torch.Tensor,
            rot: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` to per-head ``(q, k, v)``.

        Queries and keys are RMS-normalised; when ``rot`` is given,
        ``apply_rope`` is applied to both.
        """
        # x: batch_size * n_tokens * n_channels
        qkv = self.qkv(x)
        q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1)
        q = q.squeeze(-1)
        k = k.squeeze(-1)
        v = v.squeeze(-1)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if rot is not None:
            q = apply_rope(q, rot)
            k = apply_rope(k, rot)

        return q, k, v

    def forward(
        self,
        x: torch.Tensor,  # batch_size * n_tokens * n_channels
        q_mask: Optional[torch.Tensor] = None,
        k_mask: Optional[torch.Tensor] = None,
        rot: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run self-attention over ``x``.

        Args:
            x: Tokens of shape (batch, n_tokens, n_channels).
            q_mask: Optional (batch, n_tokens) mask over query tokens.
            k_mask: Optional (batch, n_tokens) mask over key tokens.
            rot: Optional rotation passed to ``pre_attention``.

        Returns:
            Attention output of shape (batch, n_tokens, n_channels).
        """
        # pre_attention returns (q, k, v) in that order; the earlier code
        # omitted ``rot`` and unpacked the result as (q, v, k).
        q, k, v = self.pre_attention(x, rot)
        if q_mask is not None or k_mask is not None:
            attn_mask = create_mask(q.shape, k.shape, q.device,
                                    q_mask=q_mask, k_mask=k_mask)
        else:
            attn_mask = None
        out = attention(q, k, v, attn_mask)
        return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``x``, keys/values from ``c``."""

    def __init__(self, dim: int, nheads: int):
        """
        Args:
            dim (int): Input dimension.
            nheads (int): Number of attention heads.

        Attributes:
            q_proj (Linear): Linear transformation for the query.
            kv_proj (Linear): Linear transformation for the key and value.
            q_norm (RMSNorm): RMS normalization for the per-head query.
            k_norm (RMSNorm): RMS normalization for the per-head key.
            split_q_into_heads (Rearrange): Splits the query into heads.
            split_kv_into_heads (Rearrange): Splits key/value into heads.
        """
        super().__init__()
        self.dim = dim
        self.nheads = nheads
        head_dim = dim // nheads

        self.q_proj = nn.Linear(dim, dim, bias=True)
        self.kv_proj = nn.Linear(dim, dim * 2, bias=True)
        self.q_norm = nn.RMSNorm(head_dim)
        self.k_norm = nn.RMSNorm(head_dim)

        self.split_q_into_heads = Rearrange('b n (h d) -> b h n d',
                                            h=nheads,
                                            d=head_dim)
        self.split_kv_into_heads = Rearrange('b n (h d j) -> b h n d j',
                                             h=nheads,
                                             d=head_dim,
                                             j=2)

    def pre_attention(
            self, x: torch.Tensor, c: torch.Tensor,
            rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project to per-head ``(q, k, v)``; ``apply_rope`` touches the query only."""
        # x: batch_size * n_tokens * n_channels
        # c: batch_size * n_cond_tokens * n_channels
        queries = self.split_q_into_heads(self.q_proj(x))
        keys, values = self.split_kv_into_heads(self.kv_proj(c)).chunk(2, dim=-1)
        keys = keys.squeeze(-1)
        values = values.squeeze(-1)

        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        if rot is not None:
            queries = apply_rope(queries, rot)

        return queries, keys, values

    def forward(
        self,
        x: torch.Tensor,  # batch_size * n_tokens * n_channels
        c: torch.Tensor,  # batch_size * n_cond_tokens * n_channels
        context_mask: Optional[torch.Tensor] = None,
        rot: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Attend from ``x`` to ``c``, optionally masking context tokens."""
        q, k, v = self.pre_attention(x, c, rot)
        attn_mask = (create_mask(q.shape, k.shape, q.device, k_mask=context_mask)
                     if context_mask is not None else None)
        return attention(q, k, v, attn_mask)
|
||||
|
||||
|
||||
class MMCrossAttentionBlock(nn.Module):
    """Cross-attention block: attention -> projection -> norm, then an FFN.

    Unlike ``MMDitSingleBlock`` there is no adaLN modulation. ``norm1`` is
    applied to the projected attention output, not to the block input.
    """

    def __init__(self,
                 dim: int,
                 nhead: int,
                 mlp_ratio: float = 4.0,
                 kernel_size: int = 7,
                 padding: int = 3,
                 residual: bool = True):
        """
        Args:
            dim: Channel size.
            nhead: Number of attention heads.
            mlp_ratio: Hidden size of the FFN as a multiple of ``dim``.
            kernel_size: 1 selects linear layers; anything else selects
                convolutional layers with this kernel size.
            padding: Padding for the convolutional layers.
            residual: Add the attention branch to the input instead of
                replacing it.
        """
        super().__init__()
        pointwise = kernel_size == 1
        hidden_dim = int(dim * mlp_ratio)

        # Submodules are created in the same order as before so parameter
        # registration (and seeded initialisation) is unchanged.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=True)
        self.attn = CrossAttention(dim, nhead)
        self.linear1 = (nn.Linear(dim, dim) if pointwise else
                        ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding))
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=True)
        self.ffn = (MLP(dim, hidden_dim) if pointwise else
                    ConvMLP(dim, hidden_dim, kernel_size=kernel_size, padding=padding))

        self.residual = residual

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
        """Return per-head ``(q, k, v)`` from the wrapped ``CrossAttention``."""
        # x: BS * N * D
        # c: conditioning tokens; no normalisation is applied before attention
        return tuple(self.attn.pre_attention(x, c, rot))

    def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor):
        """Merge the attention output into ``x`` and apply the FFN residual."""
        attn_branch = self.norm1(self.linear1(attn_out))
        # Post-norm on the attention branch; see
        # https://github.com/haidog-yaqub/EzAudio/blob/2eb0bd90013584c6e28a6c14ec28b935f1e78de5/src/models/blocks.py#L158
        # https://github.com/huggingface/diffusers/blob/07dd6f8c0e267662f62c39cd8334c2b5d157ab39/src/diffusers/models/transformers/transformer_flux.py#L170
        # https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/attention.py#L274
        x = x + attn_branch if self.residual else attn_branch
        x = x + self.ffn(self.norm2(x))
        return x

    def forward(self, x: torch.Tensor, cond: torch.Tensor,
                rot: Optional[torch.Tensor],
                x_mask: Optional[torch.Tensor] = None,
                context_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Cross-attend ``x`` to ``cond`` and return the updated ``x``.

        Args:
            x: Tokens, BS * N * D.
            cond: Conditioning tokens attended to.
            rot: Optional rotation for the queries.
            x_mask: Optional mask over the tokens of ``x``.
            context_mask: Optional mask over the tokens of ``cond``.
        """
        q, k, v = self.pre_attention(x, cond, rot)
        needs_mask = x_mask is not None or context_mask is not None
        attn_mask = (create_mask(q.shape, k.shape, q.device, q_mask=x_mask, k_mask=context_mask)
                     if needs_mask else None)
        attn_out = attention(q, k, v, attn_mask=attn_mask)
        return self.post_attention(x, attn_out)
|
||||
|
||||
|
||||
class MMDitSingleBlock(nn.Module):
    """Self-attention block with adaLN modulation from a conditioning vector.

    With ``pre_only`` the block only produces q/k/v (for joint attention) and
    ``post_attention`` returns its input unchanged.
    """

    def __init__(self,
                 dim: int,
                 nhead: int,
                 mlp_ratio: float = 4.0,
                 pre_only: bool = False,
                 kernel_size: int = 7,
                 padding: int = 3):
        """
        Args:
            dim: Channel size.
            nhead: Number of attention heads.
            mlp_ratio: Hidden size of the FFN as a multiple of ``dim``.
            pre_only: Skip building the post-attention layers.
            kernel_size: 1 selects linear layers; anything else selects
                convolutional layers with this kernel size.
            padding: Padding for the convolutional layers.
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = SelfAttention(dim, nhead)

        self.pre_only = pre_only
        if pre_only:
            # only shift/scale for the attention input are required
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
            return

        pointwise = kernel_size == 1
        hidden_dim = int(dim * mlp_ratio)
        self.linear1 = (nn.Linear(dim, dim) if pointwise else
                        ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding))
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.ffn = (MLP(dim, hidden_dim) if pointwise else
                    ConvMLP(dim, hidden_dim, kernel_size=kernel_size, padding=padding))
        # shift/scale/gate for both the attention and the FFN branch
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
        """Modulate ``x`` and project it to q/k/v.

        Returns:
            ``((q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp))``; the
            second tuple is all ``None`` when ``pre_only`` is set.
        """
        # x: BS * N * D
        # c: conditioning fed to adaLN_modulation
        modulation = self.adaLN_modulation(c)
        if self.pre_only:
            shift_msa, scale_msa = modulation.chunk(2, dim=-1)
            gate_msa = shift_mlp = scale_mlp = gate_mlp = None
        else:
            (shift_msa, scale_msa, gate_msa,
             shift_mlp, scale_mlp, gate_mlp) = modulation.chunk(6, dim=-1)

        modulated = modulate(self.norm1(x), shift_msa, scale_msa)
        qkv = tuple(self.attn.pre_attention(modulated, rot))
        return qkv, (gate_msa, shift_mlp, scale_mlp, gate_mlp)

    def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: tuple[torch.Tensor]):
        """Apply the gated attention and FFN residual updates to ``x``."""
        if self.pre_only:
            return x

        gate_msa, shift_mlp, scale_mlp, gate_mlp = c
        x = x + self.linear1(attn_out) * gate_msa
        ffn_input = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + self.ffn(ffn_input) * gate_mlp
        return x

    def forward(self, x: torch.Tensor, cond: torch.Tensor,
                rot: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the full block: modulate, self-attend, then post-attention."""
        # x: BS * N * D
        # cond: conditioning fed to adaLN_modulation
        (q, k, v), gates = self.pre_attention(x, cond, rot)
        attn_out = attention(q, k, v)
        return self.post_attention(x, attn_out, gates)
|
||||
|
||||
|
||||
class JointBlock(nn.Module):
    """Joint attention over latent, clip and text token streams.

    Each stream has its own ``MMDitSingleBlock``. Their q/k/v are concatenated
    along the token axis, attended jointly, then split back per stream.
    """

    def __init__(self, dim: int, nhead: int, mlp_ratio: float = 4.0, pre_only: bool = False):
        """
        Args:
            dim: Channel size shared by all three streams.
            nhead: Number of attention heads.
            mlp_ratio: FFN hidden size as a multiple of ``dim``.
            pre_only: When set, the clip and text streams only contribute to
                attention and are returned unchanged; the latent stream is
                always fully updated.
        """
        super().__init__()
        self.pre_only = pre_only
        self.latent_block = MMDitSingleBlock(dim,
                                             nhead,
                                             mlp_ratio,
                                             pre_only=False,
                                             kernel_size=3,
                                             padding=1)
        self.clip_block = MMDitSingleBlock(dim,
                                           nhead,
                                           mlp_ratio,
                                           pre_only=pre_only,
                                           kernel_size=3,
                                           padding=1)
        self.text_block = MMDitSingleBlock(dim, nhead, mlp_ratio, pre_only=pre_only, kernel_size=1)

    def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor,
                global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: torch.Tensor,
                clip_rot: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Update all three streams with one joint attention pass.

        Args:
            latent: BS * N1 * D latent tokens (conditioned on ``extended_c``).
            clip_f: BS * N2 * D clip tokens (conditioned on ``global_c``).
            text_f: Text tokens (conditioned on ``global_c``).
            global_c, extended_c: Conditioning, BS * (1/N) * D.
            latent_rot, clip_rot: Rotations for the latent and clip streams;
                the text stream uses none.

        Returns:
            ``(latent, clip_f, text_f)`` after the block.
        """
        x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
        c_qkv, c_mod = self.clip_block.pre_attention(clip_f, global_c, clip_rot)
        t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)

        latent_len = latent.shape[1]
        clip_len = clip_f.shape[1]

        # concatenate along the token axis: [latent | clip | text]
        joint_qkv = [torch.cat([x_qkv[i], c_qkv[i], t_qkv[i]], dim=2) for i in range(3)]

        attn_out = attention(*joint_qkv)
        x_attn_out = attn_out[:, :latent_len]
        c_attn_out = attn_out[:, latent_len:latent_len + clip_len]
        t_attn_out = attn_out[:, latent_len + clip_len:]

        latent = self.latent_block.post_attention(latent, x_attn_out, x_mod)
        if not self.pre_only:
            clip_f = self.clip_block.post_attention(clip_f, c_attn_out, c_mod)
            text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod)

        return latent, clip_f, text_f

    def forward_debug(self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor,
                      global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: torch.Tensor,
                      clip_rot: torch.Tensor,
                      layer_idx: int = -1,
                      ) -> None:
        """Dump this layer's joint attention map via ``attention_debug``.

        Builds the same joint q/k/v as ``forward`` but updates no stream and
        returns nothing. ``layer_idx`` is forwarded so the saved files can be
        told apart.
        """
        x_qkv, _ = self.latent_block.pre_attention(latent, extended_c, latent_rot)
        c_qkv, _ = self.clip_block.pre_attention(clip_f, global_c, clip_rot)
        t_qkv, _ = self.text_block.pre_attention(text_f, global_c, rot=None)

        joint_qkv = [torch.cat([x_qkv[i], c_qkv[i], t_qkv[i]], dim=2) for i in range(3)]

        attention_debug(*joint_qkv, layer_idx=layer_idx)
        return None
|
||||
|
||||
|
||||
class FinalBlock(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a 1-D convolution."""

    def __init__(self, dim, out_dim):
        """
        Args:
            dim: Channel size of the incoming latent and conditioning.
            out_dim: Channel size produced by the final convolution.
        """
        super().__init__()
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.conv = ChannelLastConv1d(dim, out_dim, kernel_size=7, padding=3)

    def forward(self, latent, c):
        """Normalise ``latent``, modulate it with ``c`` and project to ``out_dim``."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        normalised = self.norm(latent)
        return self.conv(modulate(normalised, shift, scale))
|
||||
@@ -0,0 +1,46 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution:
    """Diagonal Gaussian described by one tensor holding mean and log-variance.

    ``parameters`` is split in two halves along dim 1: the first half is the
    mean, the second the log-variance.
    """

    def __init__(self, parameters, deterministic=False):
        """
        Args:
            parameters: Tensor whose dim 1 is the concatenation of mean and
                log-variance.
            deterministic: When set, std and var are forced to zero so
                ``sample()`` returns the mean.
        """
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # clamp so the exponentials below stay finite
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self, rng: Optional[torch.Generator] = None):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, 1)`` from ``rng``."""
        noise = torch.empty_like(self.mean).normal_(generator=rng)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """Element-wise KL divergence to ``other`` (standard normal if None).

        Returns a single zero on the mean's device when deterministic.
        """
        if self.deterministic:
            return self.mean.new_zeros(1)
        if other is None:
            # The 0.5 factor covers the whole expression, so this matches the
            # ``other`` branch evaluated against a standard normal.
            return 0.5 * (torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar)
        return 0.5 * (torch.pow(self.mean - other.mean, 2) / other.var +
                      self.var / other.var - 1.0 - self.logvar + other.logvar)

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``.

        Returns a single zero on the mean's device when deterministic.
        """
        if self.deterministic:
            return self.mean.new_zeros(1)
        logtwopi = float(np.log(2.0 * np.pi))
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
                               dim=tuple(dims))

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
|
||||
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
from selva_core.utils.misc import instantiate_from_config
|
||||
from selva_core.model.networks_video_enc import TextSynch as TextSynchVideoEnc
|
||||
from selva_core.model.networks_generator import MMAudio
|
||||
|
||||
|
||||
_MODEL_ZOO = (TextSynchVideoEnc, MMAudio)
|
||||
|
||||
|
||||
def create_model_from_factory(factory_path: str, name: str, **kwargs) -> torch.nn.Module:
    """Build a model by calling the factory referenced by ``factory_path``.

    ``name`` and any extra keyword arguments are handed to
    ``instantiate_from_config`` as the factory parameters. The result must be
    an instance of one of the classes in ``_MODEL_ZOO``.
    """
    factory_params = dict(name=name, **kwargs)
    model = instantiate_from_config(factory_path, factory_params)
    assert isinstance(model, _MODEL_ZOO), f"Model {type(model)} is not a valid model type."
    return model
|
||||
@@ -0,0 +1,192 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
import open_clip
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from open_clip import create_model_from_pretrained
|
||||
from torchvision.transforms import Normalize
|
||||
from transformers import T5TokenizerFast, T5EncoderModel
|
||||
|
||||
from selva_core.ext.autoencoder import AutoEncoderModule
|
||||
from selva_core.ext.mel_converter import get_mel_converter
|
||||
from selva_core.ext.synchformer import Synchformer
|
||||
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
|
||||
from selva_core.utils.transforms import generate_multiple_segments
|
||||
|
||||
|
||||
def patch_clip(clip_model):
    """Replace ``clip_model.encode_text`` so it returns per-token hidden states.

    The replacement runs the text transformer and final layer norm and returns
    the full sequence. The model is patched in place and returned.
    """

    # a hack to make it output last hidden states
    # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
    def new_encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        hidden = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        hidden = hidden + self.positional_embedding.to(cast_dtype)
        hidden = self.transformer(hidden, attn_mask=self.attn_mask)
        hidden = self.ln_final(hidden)  # [batch_size, n_ctx, transformer.width]

        if normalize:
            return F.normalize(hidden, dim=-1)
        return hidden

    # bind the function to this instance so ``self`` is supplied automatically
    bound = new_encode_text.__get__(clip_model)
    clip_model.encode_text = bound
    return clip_model
|
||||
|
||||
|
||||
class FeaturesUtils(nn.Module):
|
||||
|
||||
def __init__(
    self,
    *,
    tod_vae_ckpt: Optional[str] = None,
    bigvgan_vocoder_ckpt: Optional[str] = None,
    synchformer_ckpt: Optional[str] = None,
    enable_conditions: bool = True,
    mode: Optional[Literal['16k', '44k']] = None,
    need_vae_encoder: bool = True,
):
    """Load the feature extractors and, optionally, the audio autoencoder.

    Args:
        tod_vae_ckpt: VAE checkpoint path; when None no autoencoder or mel
            converter is created.
        bigvgan_vocoder_ckpt: Vocoder checkpoint path forwarded to the
            autoencoder.
        synchformer_ckpt: Synchformer checkpoint path, loaded when
            ``enable_conditions`` is set.
        enable_conditions: Load CLIP, Synchformer and the T5 text encoder.
        mode: ``'16k'`` or ``'44k'``; required when ``tod_vae_ckpt`` is given.
            (This used to be declared as ``mode=Literal[...]``, which made the
            typing object itself the default value.)
        need_vae_encoder: Forwarded to the autoencoder.

    Raises:
        ValueError: If ``tod_vae_ckpt`` is given without ``mode``.
    """
    super().__init__()

    if enable_conditions:
        self.clip_model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',
                                                       return_transform=False)
        self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                         std=[0.26862954, 0.26130258, 0.27577711])
        self.clip_model = patch_clip(self.clip_model)
        self.tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')  # same as 'ViT-H-14'

        self.synchformer = Synchformer(video=True, audio=False)
        self.synchformer.load_state_dict(
            torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))

        self.text_encoder_t5 = T5EncoderModel.from_pretrained('google/flan-t5-base')
        self.tokenizer_t5 = T5TokenizerFast.from_pretrained('google/flan-t5-base')
    else:
        self.clip_model = None
        self.synchformer = None
        self.tokenizer_clip = None
        self.text_encoder_t5 = None
        self.tokenizer_t5 = None

    if tod_vae_ckpt is not None:
        if mode is None:
            raise ValueError("mode must be '16k' or '44k' when tod_vae_ckpt is given")
        self.mel_converter = get_mel_converter(mode)
        self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                                     vocoder_ckpt_path=bigvgan_vocoder_ckpt,
                                     mode=mode,
                                     need_vae_encoder=need_vae_encoder)
    else:
        self.tod = None
|
||||
|
||||
def compile(self):
|
||||
if self.clip_model is not None:
|
||||
self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
|
||||
self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
|
||||
if self.synchformer is not None:
|
||||
self.synchformer = torch.compile(self.synchformer)
|
||||
self.synchformer.forward_vfeat = torch.compile(self.synchformer.forward_vfeat)
|
||||
if self.text_encoder_t5 is not None:
|
||||
self.text_encoder_t5.forward = torch.compile(self.text_encoder_t5.forward)
|
||||
self.decode = torch.compile(self.decode)
|
||||
self.vocode = torch.compile(self.vocode)
|
||||
|
||||
def train(self, mode: bool) -> None:
|
||||
return super().train(False)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
|
||||
assert self.clip_model is not None, 'CLIP is not loaded'
|
||||
# x: (B, T, C, H, W) H/W: 384
|
||||
b, t, c, h, w = x.shape
|
||||
assert c == 3 and h == 384 and w == 384
|
||||
x = self.clip_preprocess(x)
|
||||
x = rearrange(x, 'b t c h w -> (b t) c h w')
|
||||
outputs = []
|
||||
if batch_size < 0:
|
||||
batch_size = b * t
|
||||
for i in range(0, b * t, batch_size):
|
||||
outputs.append(self.clip_model.encode_image(x[i:i + batch_size], normalize=True))
|
||||
x = torch.cat(outputs, dim=0)
|
||||
# x = self.clip_model.encode_image(x, normalize=True)
|
||||
x = rearrange(x, '(b t) d -> b t d', b=b)
|
||||
return x
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
|
||||
assert self.synchformer is not None, 'Synchformer is not loaded'
|
||||
# x: (B, T, C, H, W) H/W: 384
|
||||
|
||||
b, t, c, h, w = x.shape
|
||||
assert c == 3 and h == 224 and w == 224
|
||||
|
||||
# partition the video
|
||||
segment_size = 16
|
||||
step_size = 8
|
||||
x = generate_multiple_segments(x, segment_size, step_size) # (B, S, T, C, H, W)
|
||||
num_segments = x.shape[1]
|
||||
|
||||
outputs = []
|
||||
if batch_size < 0:
|
||||
batch_size = b
|
||||
x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
|
||||
for i in range(0, b * num_segments, batch_size):
|
||||
outputs.append(self.synchformer.forward_vfeat(x[i:i + batch_size]))
|
||||
x = torch.cat(outputs, dim=0)
|
||||
x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
|
||||
return x
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_text_clip(self, text: list[str]) -> torch.Tensor:
|
||||
assert self.clip_model is not None, 'CLIP is not loaded'
|
||||
assert self.tokenizer_clip is not None, 'Tokenizer is not loaded'
|
||||
# x: (B, L)
|
||||
tokens = self.tokenizer_clip(text).to(self.device)
|
||||
return self.clip_model.encode_text(tokens, normalize=True)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_text_t5(self, text: list[str]) -> torch.Tensor:
|
||||
device = self.text_encoder_t5.device
|
||||
batch = self.tokenizer_t5(
|
||||
text,
|
||||
max_length=self.tokenizer_t5.model_max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
|
||||
device
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.text_encoder_t5(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).last_hidden_state # (B, L, D)
|
||||
|
||||
boolean_encoder_mask = (attention_mask == 1).to(device) # (B, L)
|
||||
|
||||
return encoder_hidden_states, boolean_encoder_mask
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_audio(self, x) -> DiagonalGaussianDistribution:
|
||||
assert self.tod is not None, 'VAE is not loaded'
|
||||
# x: (B * L)
|
||||
mel = self.mel_converter(x)
|
||||
dist = self.tod.encode(mel)
|
||||
|
||||
return dist
|
||||
|
||||
@torch.inference_mode()
|
||||
def vocode(self, mel: torch.Tensor) -> torch.Tensor:
|
||||
assert self.tod is not None, 'VAE is not loaded'
|
||||
return self.tod.vocode(mel)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
assert self.tod is not None, 'VAE is not loaded'
|
||||
return self.tod.decode(z.transpose(1, 2))
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
|
||||
# Module-level logger; getLogger() without a name returns the root logger.
log = logging.getLogger()
|
||||
|
||||
|
||||
def get_parameter_groups(model, cfg, print_log=False):
    """
    Collect the trainable parameters of ``model`` into one optimizer group.

    Every parameter with ``requires_grad`` set is placed, once, in a single
    group that uses ``cfg.learning_rate`` and ``cfg.weight_decay``.

    :param model: module whose ``named_parameters()`` are inspected.
    :param cfg: object exposing ``weight_decay`` and ``learning_rate`` attributes.
    :param print_log: accepted for interface compatibility; currently unused.
    :return: a one-element list of parameter-group dicts that can be passed
             to an optimizer.
    """
    weight_decay = cfg.weight_decay
    base_lr = cfg.learning_rate

    # inspired by detectron2: remember what was already collected so that a
    # parameter is never listed twice.
    seen = set()
    trainable = []
    for _, param in model.named_parameters():
        if not param.requires_grad or param in seen:
            continue
        seen.add(param)
        trainable.append(param)

    return [
        {
            'params': trainable,
            'lr': base_lr,
            'weight_decay': weight_decay,
        },
    ]
|
||||
@@ -0,0 +1,12 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def log_normal_sample(x: torch.Tensor,
                      generator: Optional[torch.Generator] = None,
                      m: float = 0.0,
                      s: float = 1.0) -> torch.Tensor:
    """
    Draw one logit-normal sample per batch element of ``x``.

    :param x: tensor whose first dimension gives the number of samples and
              whose device is used for the result.
    :param generator: optional random generator passed to ``torch.randn``.
    :param m: mean of the underlying normal distribution.
    :param s: standard deviation of the underlying normal distribution.
    :return: a 1-D tensor of ``x.shape[0]`` values, each the sigmoid of a
             normal draw.
    """
    num_samples = x.shape[0]
    standard_normal = torch.randn(num_samples, device=x.device, generator=generator)
    # Scale and shift the standard-normal draw before squashing it with a sigmoid.
    logits = standard_normal * s + m
    return torch.sigmoid(logits)
|
||||
Reference in New Issue
Block a user