diff --git a/nodes/feature_extractor.py b/nodes/feature_extractor.py index 16a4f9d..f15536c 100644 --- a/nodes/feature_extractor.py +++ b/nodes/feature_extractor.py @@ -17,7 +17,8 @@ _EXTRACT_PACKAGES = [ "torch", "torchaudio", "torchvision", # TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+ "tensorflow-cpu>=2.16.0", - "jax[cpu]", "jaxlib", + # jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed) + "jax[cuda13]", "transformers", "decord", "einops", "numpy", "mediapy", "git+https://github.com/google-deepmind/videoprism.git", ]