fix: use videoprism_lvt_public_v1_large with joint video-text forward

The wrong model (videoprism_public_v1_large, vision-only) was used,
causing V2A audio distortion. Switch to the LvT variant which has a
text tower, pass CoT captions for joint encoding, and extract per-frame
features from outputs['frame_embeddings'] (L2-normalized, [T, 1024])
instead of manually averaging spatial patches.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-28 10:37:02 +01:00
parent 1d8b9b59e0
commit 2f626d8a96
+32 -22
View File
@@ -20,6 +20,7 @@ class FeaturesUtils:
self._t5_encoder = None self._t5_encoder = None
self._vp_model = None self._vp_model = None
self._vp_state = None self._vp_state = None
self._vp_text_tokenizer = None
self._sync_model = None self._sync_model = None
self._synchformer_ckpt = synchformer_ckpt self._synchformer_ckpt = synchformer_ckpt
@@ -70,26 +71,32 @@ class FeaturesUtils:
return return
from videoprism import models as vp from videoprism import models as vp
import jax import jax
print("[FeaturesUtils] Loading VideoPrism large (1024-dim, required by prismaudio conditioner)...") model_name = "videoprism_lvt_public_v1_large"
self._vp_model = vp.get_model("videoprism_public_v1_large") print(f"[FeaturesUtils] Loading VideoPrism LvT large (1024-dim joint video-text)...")
self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_large") self._vp_model = vp.get_model(model_name)
self._vp_state = vp.load_pretrained_weights(model_name)
self._vp_text_tokenizer = vp.load_text_tokenizer("c4_en")
jax_dev = jax.devices()[0]
self._jax_forward = jax.jit( self._jax_forward = jax.jit(
lambda x: self._vp_model.apply(self._vp_state, x, train=False) lambda x, y, z: self._vp_model.apply(
self._vp_state, x, y, z, train=False, return_intermediate=True
),
device=jax_dev,
) )
def encode_video_and_text_with_videoprism(self, clip_input, texts): def encode_video_and_text_with_videoprism(self, clip_input, texts):
""" """
Args: Args:
clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1] clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
texts: list of str (unused VideoPrism is vision-only; texts: list of str — CoT captions, passed to VideoPrism LvT text tower
global_text_features returned as zeros placeholder)
Returns: Returns:
global_video_features: Tensor [1, D] global_video_features: Tensor [1, D]
video_features: Tensor [T, D] video_features: Tensor [T, D] — per-frame L2-normalized embeddings
global_text_features: Tensor [1, D] (zeros — no text tower) global_text_features: Tensor [1, D]
""" """
self._ensure_videoprism() self._ensure_videoprism()
import jax.numpy as jnp import jax.numpy as jnp
from videoprism import models as vp
# Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array # Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array
frames = clip_input.squeeze(0) # [T, C, H, W] frames = clip_input.squeeze(0) # [T, C, H, W]
@@ -98,24 +105,27 @@ class FeaturesUtils:
frames_np = frames.cpu().numpy().astype(np.float32) frames_np = frames.cpu().numpy().astype(np.float32)
frames_jax = jnp.array(frames_np)[None] # [1, T, H, W, C] frames_jax = jnp.array(frames_np)[None] # [1, T, H, W, C]
embeddings, _ = self._jax_forward(frames_jax) # [1, T*N, D] # Tokenize text (padding value 1.0 = pad, 0.0 = real token)
text_ids, text_paddings = vp.tokenize_texts(self._vp_text_tokenizer, texts)
# Convert back to torch # Joint video+text forward with intermediate outputs
embeddings_np = np.array(embeddings) # [1, T*N, D] video_embeddings, text_embeddings, outputs = self._jax_forward(
emb = torch.from_numpy(embeddings_np).to(self.device) # [1, T*N, D] frames_jax, text_ids, text_paddings
)
T = frames_np.shape[0] # Per-frame features: [B, T, 1024] L2-normalized
D = emb.shape[-1] frame_embed_np = np.array(outputs["frame_embeddings"]) # [1, T, 1024]
N = emb.shape[1] // T # spatial patches per frame per_frame = torch.from_numpy(frame_embed_np[0]).to(self.device) # [T, 1024]
# Global video: mean over all tokens # Global video embedding: [1024] → [1, 1024]
global_video = emb.mean(dim=1) # [1, D] global_video = torch.from_numpy(
np.array(video_embeddings[0])
).unsqueeze(0).to(self.device) # [1, 1024]
# Per-frame: mean over spatial patches # Global text embedding: [1024] → [1, 1024]
per_frame = emb.view(1, T, N, D).mean(dim=2).squeeze(0) # [T, D] global_text = torch.from_numpy(
np.array(text_embeddings[0])
# Text features: zeros (VideoPrism public model is vision-only) ).unsqueeze(0).to(self.device) # [1, 1024]
global_text = torch.zeros(1, D, device=self.device)
return global_video, per_frame, global_text return global_video, per_frame, global_text