diff --git a/nodes/feature_extractor.py b/nodes/feature_extractor.py index 724f074..752a528 100644 --- a/nodes/feature_extractor.py +++ b/nodes/feature_extractor.py @@ -13,13 +13,29 @@ _PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__)) _MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env") _MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python") +def _jax_package(): + """Return the correct jax extra for the current CUDA version.""" + try: + import torch + if torch.cuda.is_available(): + cuda_ver = torch.version.cuda or "" + major = int(cuda_ver.split(".")[0]) if cuda_ver else 0 + if major >= 13: + return "jax[cuda13]" + elif major >= 12: + return "jax[cuda12]" + except Exception: + pass + return "jax" # CPU fallback + + _EXTRACT_PACKAGES = [ "torch", "torchaudio", "torchvision", # TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+ "tensorflow-cpu>=2.16.0", - # jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed) - "jax[cuda13]", "flax", - "transformers", "decord", "einops", "numpy", "mediapy", + # jax CUDA extra is resolved at install time based on detected CUDA version + _jax_package(), "flax", + "transformers", "decord", "einops", "numpy", "git+https://github.com/google-deepmind/videoprism.git", ] @@ -70,11 +86,12 @@ def _ensure_extract_env(): return _MANAGED_PYTHON -def _hash_inputs(video_tensor, cot_text): +def _hash_inputs(video_tensor, cot_text, fps): """Create a hash of the inputs for caching.""" h = hashlib.sha256() h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed h.update(cot_text.encode()) + h.update(str(fps).encode()) # fps affects frame sampling — must be part of the key return h.hexdigest()[:16] @@ -115,6 +132,10 @@ class PrismAudioFeatureExtractor: if video_info is not None: fps = video_info["loaded_fps"] + if not caption_cot.strip(): + print("[PrismAudio] Warning: caption_cot is empty — text features will be degenerate. " + "Provide a descriptive chain-of-thought caption for best results.", flush=True) + # Resolve python binary if python_env == "comfyui_env": print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. " @@ -129,7 +150,7 @@ class PrismAudioFeatureExtractor: os.makedirs(cache_dir, exist_ok=True) # Check cache - cache_hash = _hash_inputs(video, caption_cot) + cache_hash = _hash_inputs(video, caption_cot, fps) cached_path = os.path.join(cache_dir, f"{cache_hash}.npz") if os.path.exists(cached_path): print(f"[PrismAudio] Using cached features: {cached_path}") diff --git a/scripts/extract_features.py b/scripts/extract_features.py index 00303a5..45e8bd3 100755 --- a/scripts/extract_features.py +++ b/scripts/extract_features.py @@ -85,12 +85,13 @@ def main(): duration = total_frames / fps print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True) - clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] + clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))] clip_indices = [min(i, total_frames - 1) for i in clip_indices] clip_frames = all_frames[clip_indices] print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True) - sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] + # Synchformer processes in segments of 8; ensure at least 8 frames + sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))] sync_indices = [min(i, total_frames - 1) for i in sync_indices] sync_frames = all_frames[sync_indices] print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True) @@ -102,12 +103,13 @@ def main(): duration = total_frames / fps print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True) - clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] + clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))] clip_indices = [min(i, total_frames - 1) for i in clip_indices] clip_frames = vr.get_batch(clip_indices).asnumpy() print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True) - sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] + # Synchformer processes in segments of 8; ensure at least 8 frames + sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))] sync_indices = [min(i, total_frames - 1) for i in sync_indices] sync_frames = vr.get_batch(sync_indices).asnumpy() print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)