diff --git a/data_utils/v2a_utils/feature_utils_288.py b/data_utils/v2a_utils/feature_utils_288.py index b2da524..5ddd3f9 100644 --- a/data_utils/v2a_utils/feature_utils_288.py +++ b/data_utils/v2a_utils/feature_utils_288.py @@ -68,12 +68,11 @@ class FeaturesUtils: def _ensure_videoprism(self): if self._vp_model is not None: return - import videoprism as vp + from videoprism import models as vp import jax print("[FeaturesUtils] Loading VideoPrism...") self._vp_model = vp.get_model("videoprism_public_v1_base") self._vp_state = vp.load_pretrained_weights("videoprism_public_v1_base") - import jax self._jax_forward = jax.jit( lambda x: self._vp_model.apply(self._vp_state, x, train=False) ) diff --git a/nodes/feature_extractor.py b/nodes/feature_extractor.py index f15536c..3d750ca 100644 --- a/nodes/feature_extractor.py +++ b/nodes/feature_extractor.py @@ -18,7 +18,7 @@ _EXTRACT_PACKAGES = [ # TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+ "tensorflow-cpu>=2.16.0", # jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed) - "jax[cuda13]", + "jax[cuda13]", "flax", "transformers", "decord", "einops", "numpy", "mediapy", "git+https://github.com/google-deepmind/videoprism.git", ]