fix: feature extractor CUDA detection, cache correctness, and short-video crash
- Detect CUDA version at venv creation time and install matching jax[cuda12/13] instead of hardcoded jax[cuda13] — was broken on CUDA 12.x (most systems) - Include fps in cache hash: same video+caption at different fps previously returned stale cached features with wrong frame sampling - Guard frame index lists with max(1,...)/max(8,...) to prevent torch.stack([]) crash on very short input clips; sync minimum is 8 to match Synchformer's segment size requirement - Remove mediapy from managed venv packages — not imported anywhere - Warn when caption_cot is empty (produces degenerate text features) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,13 +13,29 @@ _PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__))
|
|||||||
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
|
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
|
||||||
_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")
|
_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")
|
||||||
|
|
||||||
|
def _jax_package():
|
||||||
|
"""Return the correct jax extra for the current CUDA version."""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
cuda_ver = torch.version.cuda or ""
|
||||||
|
major = int(cuda_ver.split(".")[0]) if cuda_ver else 0
|
||||||
|
if major >= 13:
|
||||||
|
return "jax[cuda13]"
|
||||||
|
elif major >= 12:
|
||||||
|
return "jax[cuda12]"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return "jax" # CPU fallback
|
||||||
|
|
||||||
|
|
||||||
_EXTRACT_PACKAGES = [
|
_EXTRACT_PACKAGES = [
|
||||||
"torch", "torchaudio", "torchvision",
|
"torch", "torchaudio", "torchvision",
|
||||||
# TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
|
# TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
|
||||||
"tensorflow-cpu>=2.16.0",
|
"tensorflow-cpu>=2.16.0",
|
||||||
# jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed)
|
# jax CUDA extra is resolved at install time based on detected CUDA version
|
||||||
"jax[cuda13]", "flax",
|
_jax_package(), "flax",
|
||||||
"transformers", "decord", "einops", "numpy", "mediapy",
|
"transformers", "decord", "einops", "numpy",
|
||||||
"git+https://github.com/google-deepmind/videoprism.git",
|
"git+https://github.com/google-deepmind/videoprism.git",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -70,11 +86,12 @@ def _ensure_extract_env():
|
|||||||
return _MANAGED_PYTHON
|
return _MANAGED_PYTHON
|
||||||
|
|
||||||
|
|
||||||
def _hash_inputs(video_tensor, cot_text):
|
def _hash_inputs(video_tensor, cot_text, fps):
|
||||||
"""Create a hash of the inputs for caching."""
|
"""Create a hash of the inputs for caching."""
|
||||||
h = hashlib.sha256()
|
h = hashlib.sha256()
|
||||||
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed
|
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed
|
||||||
h.update(cot_text.encode())
|
h.update(cot_text.encode())
|
||||||
|
h.update(str(fps).encode()) # fps affects frame sampling — must be part of the key
|
||||||
return h.hexdigest()[:16]
|
return h.hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
@@ -115,6 +132,10 @@ class PrismAudioFeatureExtractor:
|
|||||||
if video_info is not None:
|
if video_info is not None:
|
||||||
fps = video_info["loaded_fps"]
|
fps = video_info["loaded_fps"]
|
||||||
|
|
||||||
|
if not caption_cot.strip():
|
||||||
|
print("[PrismAudio] Warning: caption_cot is empty — text features will be degenerate. "
|
||||||
|
"Provide a descriptive chain-of-thought caption for best results.", flush=True)
|
||||||
|
|
||||||
# Resolve python binary
|
# Resolve python binary
|
||||||
if python_env == "comfyui_env":
|
if python_env == "comfyui_env":
|
||||||
print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
|
print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
|
||||||
@@ -129,7 +150,7 @@ class PrismAudioFeatureExtractor:
|
|||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
# Check cache
|
# Check cache
|
||||||
cache_hash = _hash_inputs(video, caption_cot)
|
cache_hash = _hash_inputs(video, caption_cot, fps)
|
||||||
cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
|
cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
|
||||||
if os.path.exists(cached_path):
|
if os.path.exists(cached_path):
|
||||||
print(f"[PrismAudio] Using cached features: {cached_path}")
|
print(f"[PrismAudio] Using cached features: {cached_path}")
|
||||||
|
|||||||
@@ -85,12 +85,13 @@ def main():
|
|||||||
duration = total_frames / fps
|
duration = total_frames / fps
|
||||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||||
|
|
||||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
|
||||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||||
clip_frames = all_frames[clip_indices]
|
clip_frames = all_frames[clip_indices]
|
||||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||||
|
|
||||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
# Synchformer processes in segments of 8; ensure at least 8 frames
|
||||||
|
sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
|
||||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||||
sync_frames = all_frames[sync_indices]
|
sync_frames = all_frames[sync_indices]
|
||||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||||
@@ -102,12 +103,13 @@ def main():
|
|||||||
duration = total_frames / fps
|
duration = total_frames / fps
|
||||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||||
|
|
||||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
|
||||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||||
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
||||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||||
|
|
||||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
# Synchformer processes in segments of 8; ensure at least 8 frames
|
||||||
|
sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
|
||||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||||
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
||||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user