# ------------------------------------------------------------------
# Synchformer visual encoder — TimeSformer-style ViT-B/16
# Architecture reverse-engineered from synchformer_state_dict.pth
# ------------------------------------------------------------------
import torch.nn.functional as F


class _PatchEmbed(nn.Module):
    """2D patch embedding: [B, 3, 224, 224] → [B, 196, 768].

    A 16x16 convolution with stride 16 turns every non-overlapping patch
    into one 768-d token; the patch grid is then flattened into a sequence.
    """

    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)

    def forward(self, x):
        # [B, 768, H/16, W/16] → [B, 768, N] → [B, N, 768]
        tokens = self.proj(x)
        return tokens.flatten(2).transpose(1, 2)


class _ViTAttn(nn.Module):
    """ViT-style multi-head self-attention (q, k, v from a single Linear).

    Args:
        dim: token embedding size.
        num_heads: number of attention heads; ``dim`` is split evenly
            across them (``head_dim = dim // num_heads``).
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend over axis 1 of ``x`` ([B, N, D]) and return [B, N, D]."""
        B, N, D = x.shape
        # [B, N, 3*D] → [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = F.softmax(scores, dim=-1)
        # Merge the heads back: [B, heads, N, head_dim] → [B, N, D]
        merged = (weights @ v).transpose(1, 2).reshape(B, N, D)
        return self.proj(merged)


class _BlockMLP(nn.Module):
    """Two-layer MLP with GELU; attributes are named ``fc1`` / ``fc2``.

    Args:
        dim: input and output feature size.
        mlp_dim: hidden layer size.
    """

    def __init__(self, dim=768, mlp_dim=3072):
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))


class _TimeSformerBlock(nn.Module):
    """Factorized space-time attention block.

    Order of operations (each with a residual connection):
    norm1 → spatial attn → norm3 → temporal attn → norm2 → MLP

    Args:
        dim: token embedding size.
        num_heads: number of heads for both attention modules.
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = _ViTAttn(dim, num_heads)
        self.norm3 = nn.LayerNorm(dim)
        self.timeattn = _ViTAttn(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = _BlockMLP(dim)

    def forward(self, x, T):
        """Apply spatial attention, temporal attention and the MLP.

        Args:
            x: [T, N, D] tokens; the frame axis acts as the batch axis for
                spatial attention.
            T: number of frames. Not used by the computation (the frame
                count is taken from ``x`` itself); kept so existing callers
                that pass it keep working.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Spatial attention: every frame attends over its own N tokens.
        x = x + self.attn(self.norm1(x))
        # Temporal attention: for each token position, attend across frames.
        # [T, N, D] → [N, T, D] → attend → [N, T, D] → [T, N, D]
        # Every position takes part, including a leading CLS token if present.
        per_position = x.permute(1, 0, 2)
        per_position = per_position + self.timeattn(self.norm3(per_position))
        x = per_position.permute(1, 0, 2)
        x = x + self.mlp(self.norm2(x))
        return x


class _SpatialAttnAgg(nn.Module):
    """Aggregate the spatial patches of each frame into one feature vector.

    A learnable CLS token is prepended to the patch tokens of every frame,
    one pre-norm self-attention + feed-forward layer is applied, and the CLS
    output is returned. Attribute names follow nn.TransformerEncoderLayer:
    self_attn, linear1, linear2, norm1, norm2.

    Args:
        dim: token embedding size.
        num_heads: number of attention heads.
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.linear1 = nn.Linear(dim, dim * 4)
        self.linear2 = nn.Linear(dim * 4, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Map [T, P, D] patch tokens (no CLS) to [T, D] per-frame features."""
        num_frames = x.shape[0]
        cls = self.cls_token.expand(num_frames, -1, -1)
        x = torch.cat([cls, x], dim=1)  # [T, P + 1, D]
        normed = self.norm1(x)
        # MultiheadAttention returns (output, weights); only the output is used.
        x = x + self.self_attn(normed, normed, normed)[0]
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x[:, 0, :]  # CLS token of every frame


class _SynchformerVisualEncoder(nn.Module):
    """TimeSformer-style ViT-B/16 visual encoder for the PrismAudio Synchformer checkpoint.

    Video is processed in independent segments of ``segment_frames`` (8)
    frames; every frame of a segment yields one 768-d feature.
    """

    def __init__(self, state_dict, device):
        """Build the network and load the ``vfeat_extractor.*`` weights.

        Args:
            state_dict: checkpoint mapping of parameter name → tensor. Only
                keys starting with ``vfeat_extractor.`` are considered, and
                ``patch_embed_3d*`` entries among them are skipped.
            device: device the module is moved to after loading.
        """
        super().__init__()
        self.device = device
        self.segment_frames = 8

        self.patch_embed = _PatchEmbed()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
        self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
        self.temp_embed = nn.Parameter(torch.zeros(1, 8, 768))
        self.blocks = nn.ModuleList([_TimeSformerBlock() for _ in range(12)])
        self.norm = nn.LayerNorm(768)
        self.spatial_attn_agg = _SpatialAttnAgg()

        # Load weights from vfeat_extractor.* prefix
        prefix = "vfeat_extractor."
        sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        # Exclude 3D patch embed (we use 2D only)
        sub = {k: v for k, v in sub.items() if not k.startswith("patch_embed_3d")}
        if not sub:
            # strict=False would otherwise leave every weight at its random
            # initial value with only a key count as a hint.
            print(f"[FeaturesUtils] Warning: no '{prefix}*' keys found in the state dict; "
                  "Synchformer encoder keeps its random initialisation.")
        missing, unexpected = self.load_state_dict(sub, strict=False)
        print(f"[FeaturesUtils] Synchformer loaded — missing={len(missing)}, unexpected={len(unexpected)}")
        if missing:
            print(f"[FeaturesUtils] missing keys (first 5): {missing[:5]}")

        self.to(device)

    def forward(self, frames):
        """Extract per-frame features for a clip.

        Args:
            frames: [T, C, H, W] float32 in [-1, 1], at 25fps

        Returns:
            [T_aligned, 768] per-frame features, where
            T_aligned = max(1, T // 8) * 8. Trailing frames that do not fill
            a whole segment are dropped; a clip shorter than one segment is
            padded to 8 frames by repeating its last frame.

        Raises:
            ValueError: if ``frames`` contains no frames.
        """
        total = frames.shape[0]
        if total == 0:
            raise ValueError("frames must contain at least one frame")

        seg = self.segment_frames
        if total < seg:
            # The temporal embedding has exactly `seg` slots, so a shorter
            # chunk cannot be embedded; repeat the last frame to fill it.
            filler = frames[-1:].expand(seg - total, *frames.shape[1:])
            frames = torch.cat([frames, filler], dim=0)

        num_seg = frames.shape[0] // seg
        results = [
            self._forward_segment(frames[i * seg:(i + 1) * seg])  # [8, C, H, W]
            for i in range(num_seg)
        ]
        return torch.cat(results, dim=0)  # [T_aligned, 768]

    def _forward_segment(self, x):
        """Encode one segment: [8, 3, 224, 224] frames → [8, 768] features.

        The temporal embedding is added by broadcasting, so ``x`` is expected
        to hold exactly ``segment_frames`` frames.
        """
        T = x.shape[0]  # 8

        # Patch embedding + CLS token
        x = self.patch_embed(x)  # [8, 196, 768]
        cls = self.cls_token.expand(T, -1, -1)
        x = torch.cat([cls, x], dim=1)  # [8, 197, 768]

        # Positional + temporal embeddings
        x = x + self.pos_embed  # broadcast (1,197,768)
        x = x + self.temp_embed.squeeze(0).unsqueeze(1)  # (8,1,768) broadcast

        # Transformer blocks (factorized space-time)
        for block in self.blocks:
            x = block(x, T)

        x = self.norm(x)

        # Aggregate spatial patches → 1 feature per frame (CLS stripped first)
        return self.spatial_attn_agg(x[:, 1:, :])  # [8, 768]