fix: correct VideoPrism import (videoprism.models, not videoprism); add flax dep

videoprism/__init__.py is empty — API lives in videoprism.models.
Fix: from videoprism import models as vp (not import videoprism as vp).
Also add flax to managed venv packages (required by videoprism Flax model).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 20:38:00 +01:00
parent 0f46e8359d
commit b1a2ee594e
2 changed files with 2 additions and 3 deletions
+1 -2
View File
@@ -68,12 +68,11 @@ class FeaturesUtils:
def _ensure_videoprism(self): def _ensure_videoprism(self):
if self._vp_model is not None: if self._vp_model is not None:
return return
import videoprism as vp from videoprism import models as vp
import jax import jax
print("[FeaturesUtils] Loading VideoPrism...") print("[FeaturesUtils] Loading VideoPrism...")
self._vp_model = vp.get_model("videoprism_public_v1_base") self._vp_model = vp.get_model("videoprism_public_v1_base")
self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_base") self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_base")
import jax
self._jax_forward = jax.jit( self._jax_forward = jax.jit(
lambda x: self._vp_model.apply(self._vp_state, x, train=False) lambda x: self._vp_model.apply(self._vp_state, x, train=False)
) )
+1 -1
View File
@@ -18,7 +18,7 @@ _EXTRACT_PACKAGES = [
# TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+ # TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
"tensorflow-cpu>=2.16.0", "tensorflow-cpu>=2.16.0",
# jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed) # jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed)
"jax[cuda13]", "jax[cuda13]", "flax",
"transformers", "decord", "einops", "numpy", "mediapy", "transformers", "decord", "einops", "numpy", "mediapy",
"git+https://github.com/google-deepmind/videoprism.git", "git+https://github.com/google-deepmind/videoprism.git",
] ]