fix: switch to VideoPrism large (1024-dim) and fix Synchformer output shape
prismaudio.json conditioner config requires: - video_features: dim=1024 → switch videoprism_public_v1_base → large (ViT-L) - sync_features: dim=768, length divisible by 8 → expand [num_seg,768] to [num_seg*8,768] (per-frame) so Sync_MLP can reshape by groups of 8 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -70,9 +70,9 @@ class FeaturesUtils:
|
||||
return
|
||||
from videoprism import models as vp
|
||||
import jax
|
||||
print("[FeaturesUtils] Loading VideoPrism...")
|
||||
self._vp_model = vp.get_model("videoprism_public_v1_base")
|
||||
self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_base")
|
||||
print("[FeaturesUtils] Loading VideoPrism large (1024-dim, required by prismaudio conditioner)...")
|
||||
self._vp_model = vp.get_model("videoprism_public_v1_large")
|
||||
self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_large")
|
||||
self._jax_forward = jax.jit(
|
||||
lambda x: self._vp_model.apply(self._vp_state, x, train=False)
|
||||
)
|
||||
@@ -233,5 +233,7 @@ class _SynchformerVisualEncoder(nn.Module):
|
||||
# Mean-pool over frames and spatial dims → [C*H*W] → project
|
||||
pooled = chunk.mean(dim=0).reshape(-1) # [C*H*W]
|
||||
feat = self._linear(pooled.unsqueeze(0)) # [1, dim]
|
||||
segs.append(feat)
|
||||
return torch.cat(segs, dim=0) # [num_seg, 768]
|
||||
# Repeat feature once per frame so output is [num_seg*8, dim]
|
||||
# Sync_MLP expects per-frame features grouped in 8-frame segments
|
||||
segs.append(feat.expand(seg, -1))
|
||||
return torch.cat(segs, dim=0) # [num_seg*8, 768]
|
||||
|
||||
Reference in New Issue
Block a user