fix: feature extractor CUDA detection, cache correctness, and short-video crash

- Detect CUDA version at venv creation time and install the matching
  jax[cuda12] or jax[cuda13] extra instead of the hardcoded jax[cuda13],
  which was broken on CUDA 12.x (most systems)
- Include fps in cache hash: same video+caption at different fps previously
  returned stale cached features with wrong frame sampling
- Guard frame index lists with max(1,...)/max(8,...) to prevent torch.stack([])
  crash on very short input clips; sync minimum is 8 to match Synchformer's
  segment size requirement
- Remove mediapy from managed venv packages: it is not imported anywhere
- Warn when caption_cot is empty (produces degenerate text features)
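
Illustrative sketch of the cache-hash change (not the code from this commit):
the point is that fps has to be part of the hashed payload, otherwise two
runs on the same video+caption at different fps map to the same cache entry.
The function name `cache_key`, its parameters, the separator, and the choice
of SHA-256 are all assumptions made for this example.

```python
import hashlib


def cache_key(video_path: str, caption: str, fps: float) -> str:
    """Hypothetical cache key: any change in fps yields a different key."""
    # A separator that cannot appear in normal text keeps the fields from
    # running into each other ("ab" + "c" vs "a" + "bc").
    payload = "\x1f".join([video_path, caption, f"{fps:.6f}"])
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


if __name__ == "__main__":
    same_a = cache_key("clip.mp4", "a dog barks", 8.0)
    same_b = cache_key("clip.mp4", "a dog barks", 8.0)
    other = cache_key("clip.mp4", "a dog barks", 25.0)
    print(same_a == same_b)  # identical inputs hit the same entry
    print(same_a == other)   # a different fps misses the cache
```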

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-28 16:00:05 +01:00
parent 4f40e15db3
commit e49f760b77
2 changed files with 32 additions and 9 deletions
+6 -4
@@ -85,12 +85,13 @@ def main():
 duration = total_frames / fps
 print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
-clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
+clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
 clip_indices = [min(i, total_frames - 1) for i in clip_indices]
 clip_frames = all_frames[clip_indices]
 print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
-sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
+# Synchformer processes in segments of 8; ensure at least 8 frames
+sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
 sync_indices = [min(i, total_frames - 1) for i in sync_indices]
 sync_frames = all_frames[sync_indices]
 print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
@@ -102,12 +103,13 @@ def main():
 duration = total_frames / fps
 print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
-clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
+clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
 clip_indices = [min(i, total_frames - 1) for i in clip_indices]
 clip_frames = vr.get_batch(clip_indices).asnumpy()
 print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
-sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
+# Synchformer processes in segments of 8; ensure at least 8 frames
+sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
 sync_indices = [min(i, total_frames - 1) for i in sync_indices]
 sync_frames = vr.get_batch(sync_indices).asnumpy()
 print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
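
The index computation touched by both hunks can be pulled into a standalone sketch to show what the guard does on a very short clip. This is an illustration only: the helper name `sample_indices` and its parameters are hypothetical, and the script in this commit keeps the logic inline. The arithmetic mirrors the diff: a minimum count via `max(...)`, then clamping each index to the last available frame, so short clips repeat their final frame rather than produce an empty list.

```python
from __future__ import annotations


def sample_indices(total_frames: int, fps: float, target_fps: float,
                   min_count: int) -> list[int]:
    """Hypothetical helper mirroring the guarded list comprehensions."""
    if total_frames <= 0:
        raise ValueError("video has no frames")
    duration = total_frames / fps
    # Without the max() guard, a very short clip gives a count of 0 and
    # the later torch.stack([]) call fails.
    count = max(min_count, int(duration * target_fps))
    # Clamp to the last frame; padding indices repeat the final frame.
    return [min(int(i * fps / target_fps), total_frames - 1)
            for i in range(count)]


if __name__ == "__main__":
    # 2-frame clip at 25 fps sampled at 8 fps: the unguarded count is 0.
    print(sample_indices(2, 25.0, 8.0, min_count=1))   # [0]
    # 5-frame clip sampled at its native 25 fps, padded to 8 frames.
    print(sample_indices(5, 25.0, 25.0, min_count=8))  # [0, 1, 2, 3, 4, 4, 4, 4]
```

Note that the padding repeats the final frame, so the eight sync frames of a sub-8-frame clip are not all distinct.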