From 63bd999dfa78850eb5a8a63f0e61bcb6e9fe166f Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 27 Mar 2026 21:07:17 +0100 Subject: [PATCH] fix: switch to VideoPrism large (1024-dim) and fix Synchformer output shape MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit prismaudio.json conditioner config requires: - video_features: dim=1024 → switch videoprism_public_v1_base → large (ViT-L) - sync_features: dim=768, length divisible by 8 → expand [num_seg,768] to [num_seg*8,768] (per-frame) so Sync_MLP can reshape by groups of 8 Co-Authored-By: Claude Sonnet 4.6 --- data_utils/v2a_utils/feature_utils_288.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/data_utils/v2a_utils/feature_utils_288.py b/data_utils/v2a_utils/feature_utils_288.py index 5ddd3f9..f7f26f5 100644 --- a/data_utils/v2a_utils/feature_utils_288.py +++ b/data_utils/v2a_utils/feature_utils_288.py @@ -70,9 +70,9 @@ class FeaturesUtils: return from videoprism import models as vp import jax - print("[FeaturesUtils] Loading VideoPrism...") - self._vp_model = vp.get_model("videoprism_public_v1_base") - self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_base") + print("[FeaturesUtils] Loading VideoPrism large (1024-dim, required by prismaudio conditioner)...") + self._vp_model = vp.get_model("videoprism_public_v1_large") + self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_large") self._jax_forward = jax.jit( lambda x: self._vp_model.apply(self._vp_state, x, train=False) ) @@ -233,5 +233,7 @@ class _SynchformerVisualEncoder(nn.Module): # Mean-pool over frames and spatial dims → [C*H*W] → project pooled = chunk.mean(dim=0).reshape(-1) # [C*H*W] feat = self._linear(pooled.unsqueeze(0)) # [1, dim] - segs.append(feat) - return torch.cat(segs, dim=0) # [num_seg, 768] + # Repeat feature once per frame so output is [num_seg*8, dim] + # Sync_MLP expects per-frame features grouped in 8-frame segments + segs.append(feat.expand(seg, -1)) + return torch.cat(segs, dim=0) # [num_seg*8, 768]