fix: switch to VideoPrism large (1024-dim) and fix Synchformer output shape

prismaudio.json conditioner config requires:
- video_features: dim=1024 → switch videoprism_public_v1_base → large (ViT-L)
- sync_features: dim=768, length divisible by 8 → expand [num_seg,768] to
  [num_seg*8,768] (per-frame) so Sync_MLP can reshape by groups of 8

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 21:07:17 +01:00
parent 20fb766ad2
commit 63bd999dfa
+7 -5
View File
@@ -70,9 +70,9 @@ class FeaturesUtils:
return return
from videoprism import models as vp from videoprism import models as vp
import jax import jax
print("[FeaturesUtils] Loading VideoPrism...") print("[FeaturesUtils] Loading VideoPrism large (1024-dim, required by prismaudio conditioner)...")
self._vp_model = vp.get_model("videoprism_public_v1_base") self._vp_model = vp.get_model("videoprism_public_v1_large")
self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_base") self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_large")
self._jax_forward = jax.jit( self._jax_forward = jax.jit(
lambda x: self._vp_model.apply(self._vp_state, x, train=False) lambda x: self._vp_model.apply(self._vp_state, x, train=False)
) )
@@ -233,5 +233,7 @@ class _SynchformerVisualEncoder(nn.Module):
# Mean-pool over frames and spatial dims → [C*H*W] → project # Mean-pool over frames and spatial dims → [C*H*W] → project
pooled = chunk.mean(dim=0).reshape(-1) # [C*H*W] pooled = chunk.mean(dim=0).reshape(-1) # [C*H*W]
feat = self._linear(pooled.unsqueeze(0)) # [1, dim] feat = self._linear(pooled.unsqueeze(0)) # [1, dim]
segs.append(feat) # Repeat feature once per frame so output is [num_seg*8, dim]
return torch.cat(segs, dim=0) # [num_seg, 768] # Sync_MLP expects per-frame features grouped in 8-frame segments
segs.append(feat.expand(seg, -1))
return torch.cat(segs, dim=0) # [num_seg*8, 768]