140cc5ee9a
Replace placeholder single-linear with proper architecture reverse-engineered from synchformer_state_dict.pth: - _PatchEmbed: Conv2d(3, 768, 16x16) → [B, 196, 768] - _TimeSformerBlock: factorized spatial + temporal attention (norm1/attn/norm3/timeattn/norm2/mlp) - _SpatialAttnAgg: TransformerEncoderLayer with CLS token, aggregates 196 patches → 1/frame - 12 blocks, dim=768, 8 frames/segment - Loads from vfeat_extractor.* prefix, skips 3D patch embed Output: [T_aligned, 768] per-frame features for Sync_MLP conditioner. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
328 lines
12 KiB
Python
328 lines
12 KiB
Python
"""
|
|
PrismAudio feature extraction utilities.
|
|
|
|
Implements FeaturesUtils used by scripts/extract_features.py to extract:
|
|
- Text features via T5-Gemma (transformers)
|
|
- Video features via VideoPrism (JAX/Flax, google-deepmind/videoprism)
|
|
- Sync features via Synchformer visual encoder (PyTorch)
|
|
"""
|
|
|
|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
|
|
class FeaturesUtils:
    """Feature extractors used by the PrismAudio extraction script.

    Three backends are wrapped:

    - T5-Gemma text encoder (loaded lazily on first use),
    - VideoPrism video encoder (JAX, loaded lazily on first use),
    - Synchformer visual encoder (loaded eagerly in ``__init__`` when a
      checkpoint path is given and exists).
    """

    def __init__(self, vae_config_path=None, synchformer_ckpt=None, device=None):
        """
        Args:
            vae_config_path: accepted for interface compatibility; not used
                by this class.
            synchformer_ckpt: path to the Synchformer checkpoint. When it is
                falsy or does not exist, sync encoding stays unavailable.
            device: torch device for the PyTorch models; defaults to CPU.
        """
        self.device = device or torch.device("cpu")
        self._t5_tokenizer = None
        self._t5_encoder = None
        self._vp_model = None
        self._vp_state = None
        # Created by _ensure_videoprism(); declared here so the attribute
        # always exists.
        self._jax_forward = None
        self._sync_model = None

        self._synchformer_ckpt = synchformer_ckpt
        self._load_synchformer()

    # ------------------------------------------------------------------
    # T5-Gemma text encoding
    # ------------------------------------------------------------------

    def _ensure_t5(self):
        """Load the T5-Gemma tokenizer and encoder once; no-op afterwards."""
        if self._t5_encoder is not None:
            return
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        model_id = "google/t5gemma-l-l-ul2-it"
        print(f"[FeaturesUtils] Loading T5-Gemma: {model_id}")
        self._t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
        self._t5_encoder = (
            AutoModelForSeq2SeqLM.from_pretrained(model_id)
            .get_encoder()
            .to(self.device)
            .eval()
        )

    def encode_t5_text(self, texts):
        """
        Encode text with the T5-Gemma encoder.

        Args:
            texts: list of str
        Returns:
            Tensor [seq_len, 1024] when a single text is given (the leading
            batch dimension is squeezed only if it has size 1).
        """
        self._ensure_t5()
        # The encoder is parked on the CPU at the end of every call (see
        # below), and _ensure_t5() does not reload it. Move it back to the
        # working device first, otherwise the second call would feed
        # device-resident tokens to a CPU-resident model.
        self._t5_encoder.to(self.device)
        tokens = self._t5_tokenizer(
            texts, return_tensors="pt", padding=True
        ).to(self.device)
        with torch.no_grad():
            out = self._t5_encoder(**tokens)
        # Move encoder off GPU to save VRAM
        self._t5_encoder.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return out.last_hidden_state.squeeze(0)  # [seq_len, 1024]

    # ------------------------------------------------------------------
    # VideoPrism video + text encoding (JAX)
    # ------------------------------------------------------------------

    def _ensure_videoprism(self):
        """Load VideoPrism and build the jitted forward once; no-op afterwards."""
        if self._vp_model is not None:
            return
        from videoprism import models as vp
        import jax

        print("[FeaturesUtils] Loading VideoPrism large (1024-dim, required by prismaudio conditioner)...")
        self._vp_model = vp.get_model("videoprism_public_v1_large")
        self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_large")
        self._jax_forward = jax.jit(
            lambda x: self._vp_model.apply(self._vp_state, x, train=False)
        )

    def encode_video_and_text_with_videoprism(self, clip_input, texts):
        """
        Encode a clip with VideoPrism and derive global / per-frame features.

        Args:
            clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
            texts: list of str (unused — VideoPrism is vision-only;
                   global_text_features returned as zeros placeholder)
        Returns:
            global_video_features: Tensor [1, D]
            video_features:        Tensor [T, D]
            global_text_features:  Tensor [1, D]  (zeros — no text tower)
        """
        self._ensure_videoprism()
        import jax.numpy as jnp

        # Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array
        frames = clip_input.squeeze(0)           # [T, C, H, W]
        frames = (frames + 1.0) / 2.0            # [-1,1] → [0,1]
        frames = frames.permute(0, 2, 3, 1)      # [T, H, W, C]
        frames_np = frames.cpu().numpy().astype(np.float32)
        frames_jax = jnp.array(frames_np)[None]  # [1, T, H, W, C]

        # Assumes the model returns T*N tokens (N per frame) — the view()
        # below fails if the token count is not a multiple of T.
        embeddings, _ = self._jax_forward(frames_jax)  # [1, T*N, D]

        # Convert back to torch
        embeddings_np = np.array(embeddings)                   # [1, T*N, D]
        emb = torch.from_numpy(embeddings_np).to(self.device)  # [1, T*N, D]

        T = frames_np.shape[0]
        D = emb.shape[-1]
        N = emb.shape[1] // T  # spatial patches per frame

        # Global video: mean over all tokens
        global_video = emb.mean(dim=1)  # [1, D]

        # Per-frame: mean over spatial patches
        per_frame = emb.view(1, T, N, D).mean(dim=2).squeeze(0)  # [T, D]

        # Text features: zeros (VideoPrism public model is vision-only)
        global_text = torch.zeros(1, D, device=self.device)

        return global_video, per_frame, global_text

    # ------------------------------------------------------------------
    # Synchformer sync feature encoding
    # ------------------------------------------------------------------

    def _load_synchformer(self):
        """Build the Synchformer encoder from the checkpoint, if one is available.

        Returns silently (leaving ``_sync_model`` as ``None``) when no path
        was given or the path does not exist.
        """
        if not self._synchformer_ckpt or not os.path.exists(self._synchformer_ckpt):
            return

        print(f"[FeaturesUtils] Loading Synchformer from: {self._synchformer_ckpt}")
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects
        # and can execute code embedded in the file. Only load checkpoints
        # from a trusted source; switch to weights_only=True if the
        # checkpoint contains nothing but tensors.
        state = torch.load(self._synchformer_ckpt, map_location="cpu", weights_only=False)

        # Checkpoint may be raw state_dict or wrapped in {"model": ...}
        if isinstance(state, dict) and "model" in state:
            state_dict = state["model"]
        else:
            state_dict = state

        self._sync_model = _SynchformerVisualEncoder(state_dict, self.device)
        self._sync_model.eval()

    def encode_video_with_sync(self, sync_input):
        """
        Run the Synchformer visual encoder on a clip.

        Args:
            sync_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
        Returns:
            sync_features: whatever the Synchformer encoder returns for the
            clip (per-frame 768-dim rows in the encoder defined in this file).
        Raises:
            RuntimeError: if no Synchformer checkpoint was loaded.
        """
        if self._sync_model is None:
            raise RuntimeError(
                "[FeaturesUtils] Synchformer checkpoint not loaded. "
                "Pass synchformer_ckpt to FeaturesUtils or set --synchformer_ckpt."
            )
        frames = sync_input.squeeze(0).to(self.device)  # [T, C, H, W]
        with torch.no_grad():
            return self._sync_model(frames)
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Synchformer visual encoder — TimeSformer-style ViT-B/16
|
|
# Architecture reverse-engineered from synchformer_state_dict.pth
|
|
# ------------------------------------------------------------------
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class _PatchEmbed(nn.Module):
    """Turn an image batch into a sequence of 16x16 patch tokens.

    A stride-16 convolution produces one 768-dim vector per patch; the patch
    grid is then flattened, e.g. [B, 3, 224, 224] → [B, 196, 768].
    """

    def __init__(self):
        super().__init__()
        # The attribute must stay named ``proj``: state-dict keys derive from it.
        patch = 16
        self.proj = nn.Conv2d(3, 768, kernel_size=patch, stride=patch)

    def forward(self, x):
        grid = self.proj(x)            # [B, 768, H/16, W/16]
        tokens = grid.flatten(2)       # [B, 768, N]
        return tokens.transpose(1, 2)  # [B, N, 768]
|
|
|
|
|
|
class _ViTAttn(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    One ``qkv`` Linear produces queries, keys and values together and a
    ``proj`` Linear mixes the heads afterwards (timm-style parameter names).
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Standard 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend over axis 1 of ``x`` ([B, N, D]) and return [B, N, D]."""
        batch, tokens, width = x.shape
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        # → three tensors of shape [B, heads, N, head_dim]
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        logits = (q @ k.transpose(-2, -1)) * self.scale
        weights = F.softmax(logits, dim=-1)
        # Merge heads back into the channel dimension.
        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, width)
        return self.proj(mixed)
|
|
|
|
|
|
class _BlockMLP(nn.Module):
    """Feed-forward sub-layer: Linear → GELU → Linear.

    The layers are named ``fc1`` / ``fc2`` so parameter keys line up with
    the checkpoint.
    """

    def __init__(self, dim=768, mlp_dim=3072):
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)

    def forward(self, x):
        hidden = self.fc1(x)     # expand to mlp_dim
        hidden = F.gelu(hidden)
        return self.fc2(hidden)  # project back to dim
|
|
|
|
|
|
class _TimeSformerBlock(nn.Module):
    """
    Transformer block with factorized space-time attention.

    Three pre-norm residual sub-layers are applied in order:
      norm1 → spatial attention   (tokens within each frame)
      norm3 → temporal attention  (same token position across frames)
      norm2 → MLP
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = _ViTAttn(dim, num_heads)
        self.norm3 = nn.LayerNorm(dim)
        self.timeattn = _ViTAttn(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = _BlockMLP(dim)

    def forward(self, x, T):
        """Process ``x`` of shape [T, N, D]; the ``T`` argument itself is not read."""
        # Spatial pass: frames act as the batch, attention runs over N tokens.
        x = x + self.attn(self.norm1(x))

        # Temporal pass: swap the first two axes so token positions act as
        # the batch and attention runs over the T frames, then swap back.
        per_position = x.transpose(0, 1)  # [N, T, D]
        per_position = per_position + self.timeattn(self.norm3(per_position))
        x = per_position.transpose(0, 1)  # [T, N, D]

        return x + self.mlp(self.norm2(x))
|
|
|
|
|
|
class _SpatialAttnAgg(nn.Module):
    """
    Collapse each frame's patch tokens into a single feature vector.

    A learnable CLS token is prepended to the patches, one pre-norm
    attention + feed-forward layer is applied, and the CLS output is
    returned. Attribute names (self_attn, linear1, linear2, norm1, norm2)
    follow nn.TransformerEncoderLayer so parameter keys match it.
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.linear1 = nn.Linear(dim, dim * 4)
        self.linear2 = nn.Linear(dim * 4, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Map patch tokens [T, P, D] (no CLS) to per-frame features [T, D]."""
        num_frames = x.size(0)
        cls = self.cls_token.expand(num_frames, -1, -1)
        seq = torch.cat((cls, x), dim=1)  # [T, P + 1, D]

        # Self-attention sub-layer (pre-norm, residual).
        normed = self.norm1(seq)
        attended, _ = self.self_attn(normed, normed, normed)
        seq = seq + attended

        # Feed-forward sub-layer (pre-norm, residual).
        feed_forward = self.linear2(F.gelu(self.linear1(self.norm2(seq))))
        seq = seq + feed_forward

        return seq[:, 0, :]  # CLS position only → [T, D]
|
|
|
|
|
|
class _SynchformerVisualEncoder(nn.Module):
    """
    TimeSformer-style ViT-B/16 visual encoder for the PrismAudio Synchformer checkpoint.

    Video is processed in independent segments of ``segment_frames`` (8)
    frames; every processed frame yields one 768-dim feature vector.
    """

    def __init__(self, state_dict, device):
        """
        Args:
            state_dict: checkpoint mapping; only keys starting with
                ``vfeat_extractor.`` are used, and ``patch_embed_3d*`` keys
                under that prefix are skipped.
            device: device the module is moved to after loading.
        """
        super().__init__()
        self.device = device
        self.segment_frames = 8

        self.patch_embed = _PatchEmbed()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
        # 197 = 1 CLS + 196 patches, i.e. the model expects 224x224 frames.
        self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
        self.temp_embed = nn.Parameter(torch.zeros(1, 8, 768))
        self.blocks = nn.ModuleList([_TimeSformerBlock() for _ in range(12)])
        self.norm = nn.LayerNorm(768)
        self.spatial_attn_agg = _SpatialAttnAgg()

        # Load weights from vfeat_extractor.* prefix
        prefix = "vfeat_extractor."
        sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        # Exclude 3D patch embed (we use 2D only)
        sub = {k: v for k, v in sub.items() if not k.startswith("patch_embed_3d")}
        # strict=False: mismatches are reported below rather than raised, so
        # any parameter listed as missing keeps its initial value.
        missing, unexpected = self.load_state_dict(sub, strict=False)
        print(f"[FeaturesUtils] Synchformer loaded — missing={len(missing)}, unexpected={len(unexpected)}")
        if missing:
            print(f"[FeaturesUtils]   missing keys (first 5): {missing[:5]}")

        self.to(device)

    def forward(self, frames):
        """
        Args:
            frames: [T, C, H, W] float32 in [-1, 1], at 25fps
        Returns:
            [T_out, 768] per-frame features. For T >= 8, T_out = floor(T/8)*8
            (trailing frames that do not fill a segment are dropped). For
            0 < T < 8 the clip is processed as one short segment, T_out = T.
        Raises:
            ValueError: if ``frames`` contains no frames.
        """
        T = frames.shape[0]
        if T == 0:
            raise ValueError("frames must contain at least one frame")
        seg = self.segment_frames
        num_seg = max(1, T // seg)

        results = [
            self._forward_segment(frames[i * seg:(i + 1) * seg])
            for i in range(num_seg)
        ]
        return torch.cat(results, dim=0)

    def _forward_segment(self, x):
        """Encode one segment ``x`` of shape [T, 3, 224, 224] (T <= 8) into [T, 768]."""
        T = x.shape[0]

        # Patch embedding + CLS token
        x = self.patch_embed(x)                 # [T, 196, 768]
        cls = self.cls_token.expand(T, -1, -1)
        x = torch.cat([cls, x], dim=1)          # [T, 197, 768]

        # Positional + temporal embeddings
        x = x + self.pos_embed                  # broadcast (1,197,768)
        # Use only the first T temporal embeddings. Adding all 8 would
        # silently broadcast a 1-frame segment to 8 rows and fail outright
        # for 2..7 frames.
        x = x + self.temp_embed[0, :T].unsqueeze(1)  # (T,1,768) broadcast

        # Transformer blocks (factorized space-time)
        for block in self.blocks:
            x = block(x, T)

        x = self.norm(x)

        # Aggregate spatial patches → 1 feature per frame (CLS stripped)
        return self.spatial_attn_agg(x[:, 1:, :])  # [T, 768]
|