fix: use videoprism_lvt_public_v1_large with joint video-text forward
The wrong model (videoprism_public_v1_large, vision-only) was used, causing V2A audio distortion. Switch to the LvT variant which has a text tower, pass CoT captions for joint encoding, and extract per-frame features from outputs['frame_embeddings'] (L2-normalized, [T, 1024]) instead of manually averaging spatial patches. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -20,6 +20,7 @@ class FeaturesUtils:
|
|||||||
self._t5_encoder = None
|
self._t5_encoder = None
|
||||||
self._vp_model = None
|
self._vp_model = None
|
||||||
self._vp_state = None
|
self._vp_state = None
|
||||||
|
self._vp_text_tokenizer = None
|
||||||
self._sync_model = None
|
self._sync_model = None
|
||||||
|
|
||||||
self._synchformer_ckpt = synchformer_ckpt
|
self._synchformer_ckpt = synchformer_ckpt
|
||||||
@@ -70,26 +71,32 @@ class FeaturesUtils:
|
|||||||
return
|
return
|
||||||
from videoprism import models as vp
|
from videoprism import models as vp
|
||||||
import jax
|
import jax
|
||||||
print("[FeaturesUtils] Loading VideoPrism large (1024-dim, required by prismaudio conditioner)...")
|
model_name = "videoprism_lvt_public_v1_large"
|
||||||
self._vp_model = vp.get_model("videoprism_public_v1_large")
|
print(f"[FeaturesUtils] Loading VideoPrism LvT large (1024-dim joint video-text)...")
|
||||||
self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_large")
|
self._vp_model = vp.get_model(model_name)
|
||||||
|
self._vp_state = vp.load_pretrained_weights(model_name)
|
||||||
|
self._vp_text_tokenizer = vp.load_text_tokenizer("c4_en")
|
||||||
|
jax_dev = jax.devices()[0]
|
||||||
self._jax_forward = jax.jit(
|
self._jax_forward = jax.jit(
|
||||||
lambda x: self._vp_model.apply(self._vp_state, x, train=False)
|
lambda x, y, z: self._vp_model.apply(
|
||||||
|
self._vp_state, x, y, z, train=False, return_intermediate=True
|
||||||
|
),
|
||||||
|
device=jax_dev,
|
||||||
)
|
)
|
||||||
|
|
||||||
def encode_video_and_text_with_videoprism(self, clip_input, texts):
|
def encode_video_and_text_with_videoprism(self, clip_input, texts):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
|
clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
|
||||||
texts: list of str (unused — VideoPrism is vision-only;
|
texts: list of str — CoT captions, passed to VideoPrism LvT text tower
|
||||||
global_text_features returned as zeros placeholder)
|
|
||||||
Returns:
|
Returns:
|
||||||
global_video_features: Tensor [1, D]
|
global_video_features: Tensor [1, D]
|
||||||
video_features: Tensor [T, D]
|
video_features: Tensor [T, D] — per-frame L2-normalized embeddings
|
||||||
global_text_features: Tensor [1, D] (zeros — no text tower)
|
global_text_features: Tensor [1, D]
|
||||||
"""
|
"""
|
||||||
self._ensure_videoprism()
|
self._ensure_videoprism()
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from videoprism import models as vp
|
||||||
|
|
||||||
# Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array
|
# Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array
|
||||||
frames = clip_input.squeeze(0) # [T, C, H, W]
|
frames = clip_input.squeeze(0) # [T, C, H, W]
|
||||||
@@ -98,24 +105,27 @@ class FeaturesUtils:
|
|||||||
frames_np = frames.cpu().numpy().astype(np.float32)
|
frames_np = frames.cpu().numpy().astype(np.float32)
|
||||||
frames_jax = jnp.array(frames_np)[None] # [1, T, H, W, C]
|
frames_jax = jnp.array(frames_np)[None] # [1, T, H, W, C]
|
||||||
|
|
||||||
embeddings, _ = self._jax_forward(frames_jax) # [1, T*N, D]
|
# Tokenize text (padding value 1.0 = pad, 0.0 = real token)
|
||||||
|
text_ids, text_paddings = vp.tokenize_texts(self._vp_text_tokenizer, texts)
|
||||||
|
|
||||||
# Convert back to torch
|
# Joint video+text forward with intermediate outputs
|
||||||
embeddings_np = np.array(embeddings) # [1, T*N, D]
|
video_embeddings, text_embeddings, outputs = self._jax_forward(
|
||||||
emb = torch.from_numpy(embeddings_np).to(self.device) # [1, T*N, D]
|
frames_jax, text_ids, text_paddings
|
||||||
|
)
|
||||||
|
|
||||||
T = frames_np.shape[0]
|
# Per-frame features: [B, T, 1024] L2-normalized
|
||||||
D = emb.shape[-1]
|
frame_embed_np = np.array(outputs["frame_embeddings"]) # [1, T, 1024]
|
||||||
N = emb.shape[1] // T # spatial patches per frame
|
per_frame = torch.from_numpy(frame_embed_np[0]).to(self.device) # [T, 1024]
|
||||||
|
|
||||||
# Global video: mean over all tokens
|
# Global video embedding: [1024] → [1, 1024]
|
||||||
global_video = emb.mean(dim=1) # [1, D]
|
global_video = torch.from_numpy(
|
||||||
|
np.array(video_embeddings[0])
|
||||||
|
).unsqueeze(0).to(self.device) # [1, 1024]
|
||||||
|
|
||||||
# Per-frame: mean over spatial patches
|
# Global text embedding: [1024] → [1, 1024]
|
||||||
per_frame = emb.view(1, T, N, D).mean(dim=2).squeeze(0) # [T, D]
|
global_text = torch.from_numpy(
|
||||||
|
np.array(text_embeddings[0])
|
||||||
# Text features: zeros (VideoPrism public model is vision-only)
|
).unsqueeze(0).to(self.device) # [1, 1024]
|
||||||
global_text = torch.zeros(1, D, device=self.device)
|
|
||||||
|
|
||||||
return global_video, per_frame, global_text
|
return global_video, per_frame, global_text
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user