Files
Ethanfel 6bc3fd6443 chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced
librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy
version incompatibility in some ComfyUI environments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 15:18:09 +02:00

199 lines
7.6 KiB
Python

import logging
from typing import Any, Mapping
import einops
import torch
from torch import nn
from selva_core.ext.synchformer.motionformer import MotionFormer, SpatialTransformerEncoderLayer, BaseEncoderLayer
from selva_core.ext.synchformer.astransformer import AST
from selva_core.model.transformer_layers import (MMCrossAttentionBlock)
from selva_core.model.low_level import MLP
class ExtendedMotionFormer(MotionFormer):
    """MotionFormer that exposes feature extraction and aggregation as two steps.

    Keeping the stages separate lets a caller transform the per-token
    features (for instance, condition them on text) before they are pooled.
    """

    def forward_segments_without_aggregation(self, x, orig_shape: tuple) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the backbone and return token features with no space/time pooling.

        Args:
            x: Input tensor of shape (B*S, C, T, H, W), S being the number of segments.
            orig_shape: Original shape tuple (B, S, C, T, H, W).

        Returns:
            ``(features, mask)``. When ``self.factorize_space_time`` is set the
            features are restored to (B*S, D, t, h, w); otherwise they stay in
            the flat (B*S, tokens, D) layout. The mask is returned exactly as
            produced by ``forward_features``.
        """
        tokens, token_mask = self.forward_features(x)
        assert self.extract_features

        # Token layout is (B*S, 1 + n_patches, D); per the upstream note,
        # 1 + (224 // 16) * (224 // 16) * 8 tokens. Drop the leading CLS token
        # for efficiency, then normalise and project.
        tokens = self.pre_logits(self.norm(tokens[:, 1:, :]))

        if not self.factorize_space_time:
            return tokens, token_mask

        # (B*S, t*h*w, D) -> (B*S, D, t, h, w)
        return self.restore_spatio_temp_dims(tokens, orig_shape), token_mask

    def spatiotemporal_aggregation(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
        """
        Pool token features over space and then over time.

        Args:
            x: Features tensor of shape (B*S, D, t, h, w).
            x_mask: Mask tensor forwarded to the spatial aggregation module.

        Returns:
            Aggregated features of shape (B*S, D), or (B*S, t, D) when
            ``self.temp_attn_agg`` is an ``Identity``. The input is returned
            untouched when ``self.factorize_space_time`` is falsy.
        """
        if not self.factorize_space_time:
            return x

        per_frame = self.spatial_attn_agg(x, x_mask)  # (B*S, t, D)
        return self.temp_attn_agg(per_frame)
class TextSynchformer(nn.Module):
    """Synchformer-style feature extractor whose visual tokens cross-attend to text.

    Video segments are encoded by an ``ExtendedMotionFormer``; the resulting
    un-pooled tokens are passed through ``xattn_depth`` cross-attention blocks
    with projected text features as context, and only then aggregated.
    The audio branch instantiates an ``AST`` extractor but its forward pass is
    not implemented.

    Args:
        video: Build and use the video feature extractor.
        audio: Build the audio feature extractor (forward not implemented).
        text_dim: Size of the last dimension of the incoming text features.
        max_text_seq_len: Stored on the instance; not otherwise used by the
            code in this class.
        xattn_depth: Number of text cross-attention blocks.

    Raises:
        ValueError: If both ``video`` and ``audio`` are False.
    """

    def __init__(self, video: bool = True, audio: bool = False,
                 text_dim: int = 1024, max_text_seq_len: int = 512, xattn_depth: int = 1):
        super().__init__()
        self.video = video
        self.audio = audio
        self.text_dim = text_dim
        self.max_text_seq_len = max_text_seq_len

        if not video and not audio:
            raise ValueError('At least one of video or audio should be True.')

        # Use ExtendedMotionFormer directly instead of inheriting from Synchformer
        if self.video:
            self.vfeat_extractor = ExtendedMotionFormer(
                extract_features=True,
                factorize_space_time=True,
                agg_space_module='TransformerEncoderLayer',
                agg_time_module='torch.nn.Identity',
                add_global_repr=False
            )
        if self.audio:
            self.afeat_extractor = AST(
                extract_features=True,
                max_spec_t=66,
                factorize_freq_time=True,
                agg_freq_module='TransformerEncoderLayer',
                agg_time_module='torch.nn.Identity',
                add_global_repr=False
            )

        # The cross-attention stack mirrors the video backbone's geometry when
        # one exists; otherwise fall back to fixed defaults.
        if self.video:
            self.embed_dim = self.vfeat_extractor.embed_dim
            self.num_heads = self.vfeat_extractor.num_heads
            self.mlp_ratio = self.vfeat_extractor.mlp_ratio
        else:
            self.embed_dim = 768
            self.num_heads = 12
            self.mlp_ratio = 4

        # Maps text features from text_dim to embed_dim (applied over the last dim).
        self.text_proj = nn.Sequential(
            nn.Linear(self.text_dim, self.embed_dim),
            nn.SiLU(),
            MLP(self.embed_dim, self.embed_dim * 4)
        )

        self.synch_text_cross_blocks = nn.ModuleList([
            MMCrossAttentionBlock(self.embed_dim, self.num_heads,
                                  mlp_ratio=self.mlp_ratio,
                                  kernel_size=1, padding=0,
                                  residual=True)
            for _ in range(xattn_depth)
        ])

    def initialize_weights(self):
        """Initialise the text projection and cross-attention blocks.

        Every ``nn.Linear`` under ``text_proj`` and ``synch_text_cross_blocks``
        gets Xavier-uniform weights and zero bias. Afterwards each block's
        ``norm1`` weight/bias and ``ffn.w2`` weight are set to zero.
        Not called from ``__init__``; callers must invoke it explicitly.
        The feature extractors are left untouched.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.text_proj.apply(_basic_init)
        self.synch_text_cross_blocks.apply(_basic_init)

        for block in self.synch_text_cross_blocks:
            nn.init.constant_(block.norm1.weight, 0.0)
            nn.init.constant_(block.norm1.bias, 0.0)
            nn.init.constant_(block.ffn.w2.weight, 0.0)

    def forward(self, data, text_features, text_mask=None):
        """Extract text-conditioned features for the enabled modalities.

        Args:
            data: ``(video, audio)`` when both modalities are enabled,
                otherwise the single tensor for the enabled modality.
            text_features: Text features handed to the per-modality forward.
            text_mask: Optional mask for the text features, forwarded as the
                cross-attention ``context_mask``. Defaults to None (no mask).

        Returns:
            ``(video, audio)`` when both modalities are enabled, otherwise the
            features of the single enabled modality.

        Raises:
            NotImplementedError: If audio is enabled and audio data is given.
        """
        video, audio = None, None
        if self.video and self.audio:
            video, audio = data
        elif self.video:
            video = data
        elif self.audio:
            audio = data

        # text_mask must be passed explicitly: forward_vfeat takes it as its
        # third argument, and omitting it used to raise TypeError here.
        if self.video and video is not None:
            video = self.forward_vfeat(video, text_features, text_mask)
        if self.audio and audio is not None:
            audio = self.forward_afeat(audio, text_features, text_mask)

        if self.video and self.audio:
            return video, audio
        if self.video:
            return video
        return audio

    def forward_vfeat(self, vis, text_f, text_mask=None):
        """Encode video segments, cross-attend to text, then aggregate.

        Args:
            vis: Video tensor of shape (B, S, Tv, C, H, W).
            text_f: Text features whose last dimension is ``text_dim``
                (presumably (B, L, text_dim) since they serve as attention
                context; confirm against the caller).
            text_mask: Optional mask passed to each block as ``context_mask``.

        Returns:
            Aggregated features reshaped to (B, S, ...); the trailing dims are
            whatever ``spatiotemporal_aggregation`` yields per segment.
        """
        B, S, Tv, C, H, W = vis.shape
        vis = vis.permute(0, 1, 3, 2, 4, 5)  # (B, S, C, Tv, H, W)

        # Fold segments into the batch so the backbone sees (B*S, C, Tv, H, W).
        orig_shape = (B, S, C, Tv, H, W)
        vis = einops.rearrange(vis, 'B S C Tv H W -> (B S) C Tv H W')

        # vis: (B*S, D, t, h, w); seg_mask: (B*S, t, h, w) or None
        vis, seg_mask = self.vfeat_extractor.forward_segments_without_aggregation(
            vis, orig_shape
        )

        text_f = self.text_proj(text_f)  # last dim: text_dim -> embed_dim

        _, D, t, h, w = vis.shape
        # Flatten all segment tokens of one sample into a single sequence so
        # the text can attend across every segment at once.
        vis = einops.rearrange(vis, '(B S) D t h w -> B (S t h w) D', B=B, S=S)
        flat_mask = None
        if seg_mask is not None:
            flat_mask = einops.rearrange(seg_mask, '(B S) t h w -> B (S t h w)', B=B, S=S)

        for block in self.synch_text_cross_blocks:
            vis = block(vis, text_f, rot=None, x_mask=flat_mask, context_mask=text_mask)

        vis = einops.rearrange(vis, 'B (S t h w) D -> (B S) D t h w', B=B, S=S, D=D, t=t, h=h, w=w)
        # The blocks never reassign the mask, so the per-segment mask from the
        # backbone is reused rather than un-flattening flat_mask.
        vis = self.vfeat_extractor.spatiotemporal_aggregation(vis, seg_mask)

        return vis.view(B, S, *vis.shape[1:])

    def forward_afeat(self, aud, text_f=None, text_mask=None):
        """Forward audio features (not implemented).

        ``text_f`` and ``text_mask`` are accepted so ``forward`` can call this
        method the same way it calls ``forward_vfeat``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Audio feature extraction is not implemented in TextSynchformer.")

    def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True, **kwargs):
        """Load only the entries that belong to this module's sub-modules.

        Keys are kept when they start with ``vfeat_extractor`` (video enabled),
        ``afeat_extractor`` (audio enabled), ``text_proj`` or
        ``synch_text_cross_blocks``; every other entry is discarded before
        delegating to ``nn.Module.load_state_dict``.

        Args:
            sd: Source state dict, possibly containing unrelated entries.
            strict: Forwarded to ``nn.Module.load_state_dict``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``nn.Module.load_state_dict``.

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns.
        """
        prefixes = []
        if self.video:
            prefixes.append('vfeat_extractor')
        if self.audio:
            prefixes.append('afeat_extractor')
        prefixes += ['text_proj', 'synch_text_cross_blocks']
        prefixes = tuple(prefixes)

        filtered = {k: v for k, v in sd.items() if k.startswith(prefixes)}
        return super().load_state_dict(filtered, strict, **kwargs)