diff --git a/data_utils/v2a_utils/feature_utils_288.py b/data_utils/v2a_utils/feature_utils_288.py index 5ddd3f9..f7f26f5 100644 --- a/data_utils/v2a_utils/feature_utils_288.py +++ b/data_utils/v2a_utils/feature_utils_288.py @@ -70,9 +70,9 @@ class FeaturesUtils: return from videoprism import models as vp import jax - print("[FeaturesUtils] Loading VideoPrism...") - self._vp_model = vp.get_model("videoprism_public_v1_base") - self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_base") + print("[FeaturesUtils] Loading VideoPrism large (1024-dim, required by prismaudio conditioner)...") + self._vp_model = vp.get_model("videoprism_public_v1_large") + self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_large") self._jax_forward = jax.jit( lambda x: self._vp_model.apply(self._vp_state, x, train=False) ) @@ -233,5 +233,7 @@ class _SynchformerVisualEncoder(nn.Module): # Mean-pool over frames and spatial dims → [C*H*W] → project pooled = chunk.mean(dim=0).reshape(-1) # [C*H*W] feat = self._linear(pooled.unsqueeze(0)) # [1, dim] - segs.append(feat) - return torch.cat(segs, dim=0) # [num_seg, 768] + # Repeat feature once per frame so output is [num_seg*8, dim] + # Sync_MLP expects per-frame features grouped in 8-frame segments + segs.append(feat.expand(seg, -1)) + return torch.cat(segs, dim=0) # [num_seg*8, 768]