From b1a2ee594ef134c1cfa1120a8c9c46ed78d7282b Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 27 Mar 2026 20:38:00 +0100 Subject: [PATCH] fix: correct VideoPrism import (videoprism.models, not videoprism); add flax dep MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit videoprism/__init__.py is empty — API lives in videoprism.models. Fix: from videoprism import models as vp (not import videoprism as vp). Also add flax to managed venv packages (required by videoprism Flax model). Co-Authored-By: Claude Sonnet 4.6 --- data_utils/v2a_utils/feature_utils_288.py | 3 +-- nodes/feature_extractor.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/data_utils/v2a_utils/feature_utils_288.py b/data_utils/v2a_utils/feature_utils_288.py index b2da524..5ddd3f9 100644 --- a/data_utils/v2a_utils/feature_utils_288.py +++ b/data_utils/v2a_utils/feature_utils_288.py @@ -68,12 +68,11 @@ class FeaturesUtils: def _ensure_videoprism(self): if self._vp_model is not None: return - import videoprism as vp + from videoprism import models as vp import jax print("[FeaturesUtils] Loading VideoPrism...") self._vp_model = vp.get_model("videoprism_public_v1_base") self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_base") - import jax self._jax_forward = jax.jit( lambda x: self._vp_model.apply(self._vp_state, x, train=False) ) diff --git a/nodes/feature_extractor.py b/nodes/feature_extractor.py index f15536c..3d750ca 100644 --- a/nodes/feature_extractor.py +++ b/nodes/feature_extractor.py @@ -18,7 +18,7 @@ _EXTRACT_PACKAGES = [ # TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+ "tensorflow-cpu>=2.16.0", # jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed) - "jax[cuda13]", + "jax[cuda13]", "flax", "transformers", "decord", "einops", "numpy", "mediapy", "git+https://github.com/google-deepmind/videoprism.git", ]