feat: implement real Synchformer visual encoder (TimeSformer ViT-B/16)

Replace placeholder single-linear with proper architecture reverse-engineered
from synchformer_state_dict.pth:
- _PatchEmbed: Conv2d(3, 768, 16x16) → [B, 196, 768]
- _TimeSformerBlock: factorized spatial + temporal attention (norm1/attn/norm3/timeattn/norm2/mlp)
- _SpatialAttnAgg: TransformerEncoderLayer with CLS token, aggregates 196 patches → 1/frame
- 12 blocks, dim=768, 8 frames/segment
- Loads from vfeat_extractor.* prefix, skips 3D patch embed

Output: [T_aligned, 768] per-frame features for Sync_MLP conditioner.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 21:28:20 +01:00
parent f99d2666e8
commit 140cc5ee9a
+144 -56
View File
@@ -157,17 +157,108 @@ class FeaturesUtils:
# ------------------------------------------------------------------
# Synchformer visual encoder — loads from the PrismAudio checkpoint
# Synchformer visual encoder — TimeSformer-style ViT-B/16
# Architecture reverse-engineered from synchformer_state_dict.pth
# ------------------------------------------------------------------
import torch.nn.functional as F
class _PatchEmbed(nn.Module):
"""2D patch embedding: [B, 3, 224, 224] → [B, 196, 768]."""
def __init__(self):
super().__init__()
self.proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
def forward(self, x):
return self.proj(x).flatten(2).transpose(1, 2)
class _ViTAttn(nn.Module):
"""ViT-style QKV attention (timm convention: qkv as single Linear)."""
def __init__(self, dim=768, num_heads=12):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
B, N, D = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
return self.proj((attn @ v).transpose(1, 2).reshape(B, N, D))
class _BlockMLP(nn.Module):
"""Two-layer MLP with GELU, keys fc1/fc2 to match checkpoint."""
def __init__(self, dim=768, mlp_dim=3072):
super().__init__()
self.fc1 = nn.Linear(dim, mlp_dim)
self.fc2 = nn.Linear(mlp_dim, dim)
def forward(self, x):
return self.fc2(F.gelu(self.fc1(x)))
class _TimeSformerBlock(nn.Module):
"""
Factorized space-time attention block.
norm1 → spatial attn → norm3 → temporal attn → norm2 → MLP
"""
def __init__(self, dim=768, num_heads=12):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = _ViTAttn(dim, num_heads)
self.norm3 = nn.LayerNorm(dim)
self.timeattn = _ViTAttn(dim, num_heads)
self.norm2 = nn.LayerNorm(dim)
self.mlp = _BlockMLP(dim)
def forward(self, x, T):
# x: [T, N, D] (T frames treated as batch, N=197 spatial tokens)
x = x + self.attn(self.norm1(x))
# Temporal attention: for each spatial position, attend across T frames
# [T, N, D] → [N, T, D] → attend → [N, T, D] → [T, N, D]
xt = x.permute(1, 0, 2)
xt = xt + self.timeattn(self.norm3(xt))
x = xt.permute(1, 0, 2)
x = x + self.mlp(self.norm2(x))
return x
class _SpatialAttnAgg(nn.Module):
"""
Aggregates 196 spatial patches → 1 feature per frame using a
TransformerEncoderLayer with a learnable CLS token.
Key names match nn.TransformerEncoderLayer: self_attn, linear1, linear2, norm1, norm2.
"""
def __init__(self, dim=768, num_heads=12):
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.linear1 = nn.Linear(dim, dim * 4)
self.linear2 = nn.Linear(dim * 4, dim)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
def forward(self, x):
# x: [T, 196, 768] — spatial patches (CLS stripped)
T = x.shape[0]
cls = self.cls_token.expand(T, -1, -1)
x = torch.cat([cls, x], dim=1) # [T, 197, 768]
xn = self.norm1(x)
x = x + self.self_attn(xn, xn, xn)[0]
x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
return x[:, 0, :] # [T, 768] — CLS per frame
class _SynchformerVisualEncoder(nn.Module):
"""
Minimal visual feature extractor compatible with the PrismAudio
synchformer_state_dict.pth checkpoint.
Inspects the state dict key prefixes to route to the right sub-module.
The encoder processes video in segments of 8 frames and returns
[num_segments, 768] features for the Sync_MLP conditioner.
TimeSformer-style ViT-B/16 visual encoder for the PrismAudio Synchformer checkpoint.
Processes video in segments of 8 frames → [T_aligned, 768] per-frame features.
"""
def __init__(self, state_dict, device):
@@ -175,65 +266,62 @@ class _SynchformerVisualEncoder(nn.Module):
self.device = device
self.segment_frames = 8
# Determine architecture from state dict keys
keys = list(state_dict.keys())
prefix = self._detect_prefix(keys)
print(f"[FeaturesUtils] Synchformer state dict prefix detected: '{prefix}'")
self.patch_embed = _PatchEmbed()
self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
self.temp_embed = nn.Parameter(torch.zeros(1, 8, 768))
self.blocks = nn.ModuleList([_TimeSformerBlock() for _ in range(12)])
self.norm = nn.LayerNorm(768)
self.spatial_attn_agg = _SpatialAttnAgg()
self._build_from_state_dict(state_dict, prefix, device)
def _detect_prefix(self, keys):
for candidate in ("vfeat_extractor.", "visual_encoder.", "encoder.", ""):
if any(k.startswith(candidate) for k in keys):
return candidate
return ""
def _build_from_state_dict(self, state_dict, prefix, device):
# Extract sub-dict for the visual encoder
if prefix:
sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
else:
sub = state_dict
# Detect output dimension from the last linear layer
dim = 768
for k, v in reversed(list(sub.items())):
if v.dim() == 2:
dim = v.shape[0]
break
# Build a simple temporal MLP that projects patch tokens → segment features.
# If the checkpoint has a known transformer structure we load it; otherwise
# we fall back to a projection layer that maps mean-pooled tokens to 768-d.
self._dim = dim
self._linear = nn.Linear(3 * 224 * 224, dim, bias=False)
# Try to load the sub-dict; ignore mismatches gracefully
try:
missing, unexpected = self.load_state_dict({"_linear.weight": sub.get("proj.weight", sub.get("head.weight", next(iter(sub.values()))))}, strict=False)
print(f"[FeaturesUtils] Synchformer loaded (missing={len(missing)}, unexpected={len(unexpected)})")
except Exception as e:
print(f"[FeaturesUtils] Warning: could not load Synchformer weights ({e}). Using random init.")
# Load weights from vfeat_extractor.* prefix
prefix = "vfeat_extractor."
sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
# Exclude 3D patch embed (we use 2D only)
sub = {k: v for k, v in sub.items() if not k.startswith("patch_embed_3d")}
missing, unexpected = self.load_state_dict(sub, strict=False)
print(f"[FeaturesUtils] Synchformer loaded — missing={len(missing)}, unexpected={len(unexpected)}")
if missing:
print(f"[FeaturesUtils] missing keys (first 5): {missing[:5]}")
self.to(device)
def forward(self, frames):
"""
Args:
frames: Tensor [T, C, H, W] — video frames at 25fps, normalised to [-1,1]
frames: [T, C, H, W] float32 in [-1, 1], at 25fps
Returns:
Tensor [num_segments, 768]
[T_aligned, 768] — per-frame features (T_aligned = floor(T/8)*8)
"""
T = frames.shape[0]
seg = self.segment_frames
num_seg = max(1, T // seg)
segs = []
T_aligned = num_seg * seg
results = []
for i in range(num_seg):
chunk = frames[i * seg: (i + 1) * seg] # [seg, C, H, W]
# Mean-pool over frames and spatial dims → [C*H*W] → project
pooled = chunk.mean(dim=0).reshape(-1) # [C*H*W]
feat = self._linear(pooled.unsqueeze(0)) # [1, dim]
# Repeat feature once per frame so output is [num_seg*8, dim]
# Sync_MLP expects per-frame features grouped in 8-frame segments
segs.append(feat.expand(seg, -1))
return torch.cat(segs, dim=0) # [num_seg*8, 768]
chunk = frames[i * seg:(i + 1) * seg] # [8, C, H, W]
results.append(self._forward_segment(chunk))
return torch.cat(results, dim=0) # [T_aligned, 768]
def _forward_segment(self, x):
# x: [8, 3, 224, 224]
T = x.shape[0] # 8
# Patch embedding + CLS token
x = self.patch_embed(x) # [8, 196, 768]
cls = self.cls_token.expand(T, -1, -1)
x = torch.cat([cls, x], dim=1) # [8, 197, 768]
# Positional + temporal embeddings
x = x + self.pos_embed # broadcast (1,197,768)
x = x + self.temp_embed.squeeze(0).unsqueeze(1) # (8,1,768) broadcast
# Transformer blocks (factorized space-time)
for block in self.blocks:
x = block(x, T)
x = self.norm(x)
# Aggregate spatial patches → 1 feature per frame
return self.spatial_attn_agg(x[:, 1:, :]) # [8, 768]