From 2f626d8a96651c69ac46632e4728c2ad2fe0fc6c Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Mar 2026 10:37:02 +0100 Subject: [PATCH] fix: use videoprism_lvt_public_v1_large with joint video-text forward The wrong model (videoprism_public_v1_large, vision-only) was used, causing V2A audio distortion. Switch to the LvT variant which has a text tower, pass CoT captions for joint encoding, and extract per-frame features from outputs['frame_embeddings'] (L2-normalized, [T, 1024]) instead of manually averaging spatial patches. Co-Authored-By: Claude Sonnet 4.6 --- data_utils/v2a_utils/feature_utils_288.py | 54 ++++++++++++++--------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/data_utils/v2a_utils/feature_utils_288.py b/data_utils/v2a_utils/feature_utils_288.py index 8ae576f..245d5e7 100644 --- a/data_utils/v2a_utils/feature_utils_288.py +++ b/data_utils/v2a_utils/feature_utils_288.py @@ -20,6 +20,7 @@ class FeaturesUtils: self._t5_encoder = None self._vp_model = None self._vp_state = None + self._vp_text_tokenizer = None self._sync_model = None self._synchformer_ckpt = synchformer_ckpt @@ -70,26 +71,32 @@ class FeaturesUtils: return from videoprism import models as vp import jax - print("[FeaturesUtils] Loading VideoPrism large (1024-dim, required by prismaudio conditioner)...") - self._vp_model = vp.get_model("videoprism_public_v1_large") - self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_large") + model_name = "videoprism_lvt_public_v1_large" + print(f"[FeaturesUtils] Loading VideoPrism LvT large (1024-dim joint video-text)...") + self._vp_model = vp.get_model(model_name) + self._vp_state = vp.load_pretrained_weights(model_name) + self._vp_text_tokenizer = vp.load_text_tokenizer("c4_en") + jax_dev = jax.devices()[0] self._jax_forward = jax.jit( - lambda x: self._vp_model.apply(self._vp_state, x, train=False) + lambda x, y, z: self._vp_model.apply( + self._vp_state, x, y, z, train=False, return_intermediate=True + ), + device=jax_dev, ) def encode_video_and_text_with_videoprism(self, clip_input, texts): """ Args: clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1] - texts: list of str (unused — VideoPrism is vision-only; - global_text_features returned as zeros placeholder) + texts: list of str — CoT captions, passed to VideoPrism LvT text tower Returns: global_video_features: Tensor [1, D] - video_features: Tensor [T, D] - global_text_features: Tensor [1, D] (zeros — no text tower) + video_features: Tensor [T, D] — per-frame L2-normalized embeddings + global_text_features: Tensor [1, D] """ self._ensure_videoprism() import jax.numpy as jnp + from videoprism import models as vp # Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array frames = clip_input.squeeze(0) # [T, C, H, W] @@ -98,24 +105,27 @@ class FeaturesUtils: frames_np = frames.cpu().numpy().astype(np.float32) frames_jax = jnp.array(frames_np)[None] # [1, T, H, W, C] - embeddings, _ = self._jax_forward(frames_jax) # [1, T*N, D] + # Tokenize text (padding value 1.0 = pad, 0.0 = real token) + text_ids, text_paddings = vp.tokenize_texts(self._vp_text_tokenizer, texts) - # Convert back to torch - embeddings_np = np.array(embeddings) # [1, T*N, D] - emb = torch.from_numpy(embeddings_np).to(self.device) # [1, T*N, D] + # Joint video+text forward with intermediate outputs + video_embeddings, text_embeddings, outputs = self._jax_forward( + frames_jax, text_ids, text_paddings + ) - T = frames_np.shape[0] - D = emb.shape[-1] - N = emb.shape[1] // T # spatial patches per frame + # Per-frame features: [B, T, 1024] L2-normalized + frame_embed_np = np.array(outputs["frame_embeddings"]) # [1, T, 1024] + per_frame = torch.from_numpy(frame_embed_np[0]).to(self.device) # [T, 1024] - # Global video: mean over all tokens - global_video = emb.mean(dim=1) # [1, D] + # Global video embedding: [1024] → [1, 1024] + global_video = torch.from_numpy( + np.array(video_embeddings[0]) + ).unsqueeze(0).to(self.device) # [1, 1024] - # Per-frame: mean over spatial patches - per_frame = emb.view(1, T, N, D).mean(dim=2).squeeze(0) # [T, D] - - # Text features: zeros (VideoPrism public model is vision-only) - global_text = torch.zeros(1, D, device=self.device) + # Global text embedding: [1024] → [1, 1024] + global_text = torch.from_numpy( + np.array(text_embeddings[0]) + ).unsqueeze(0).to(self.device) # [1, 1024] return global_video, per_frame, global_text