fix: feature extractor CUDA detection, cache correctness, and short-video crash
- Detect the CUDA version at venv creation time and install the matching jax[cuda12] or jax[cuda13] extra instead of the hardcoded jax[cuda13], which was broken on CUDA 12.x (most systems).
- Include fps in the cache hash: the same video+caption at a different fps previously returned stale cached features with the wrong frame sampling.
- Guard the frame index lists with max(1, ...)/max(8, ...) to prevent a torch.stack([]) crash on very short input clips; the sync minimum is 8 to match Synchformer's segment size requirement.
- Remove mediapy from the managed venv packages — it is not imported anywhere.
- Warn when caption_cot is empty (this produces degenerate text features).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -85,12 +85,13 @@ def main():
|
||||
duration = total_frames / fps
|
||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||
|
||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
|
||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||
clip_frames = all_frames[clip_indices]
|
||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||
|
||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
||||
# Synchformer processes in segments of 8; ensure at least 8 frames
|
||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
|
||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||
sync_frames = all_frames[sync_indices]
|
||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||
@@ -102,12 +103,13 @@ def main():
|
||||
duration = total_frames / fps
|
||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||
|
||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
|
||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||
|
||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
||||
# Synchformer processes in segments of 8; ensure at least 8 frames
|
||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
|
||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||
|
||||
Reference in New Issue
Block a user