perf: replace MP4 encode/decode with lossless .npy frame transfer
Saves frames as uint8 .npy instead of H.264 MP4, eliminating the lossy codec roundtrip. extract_features.py loads .npy directly and skips decord when given a numpy file. Passes --source_fps for correct temporal sampling. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+33
-15
@@ -42,6 +42,7 @@ def main():
|
||||
parser.add_argument("--output", required=True, help="Output .npz path")
|
||||
parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
|
||||
parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
|
||||
parser.add_argument("--source_fps", type=float, default=30.0, help="Original video fps (used when --video is a .npy file)")
|
||||
parser.add_argument("--clip_fps", type=float, default=4.0)
|
||||
parser.add_argument("--clip_size", type=int, default=288)
|
||||
parser.add_argument("--sync_fps", type=float, default=25.0)
|
||||
@@ -64,7 +65,6 @@ def main():
|
||||
t0 = _step(1, 6, "importing dependencies")
|
||||
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
|
||||
import torchvision.transforms as T
|
||||
from decord import VideoReader, cpu
|
||||
_done(t0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -78,16 +78,39 @@ def main():
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
t0 = _step(3, 6, "reading and preprocessing video")
|
||||
vr = VideoReader(args.video, ctx=cpu(0))
|
||||
fps = vr.get_avg_fps()
|
||||
total_frames = len(vr)
|
||||
duration = total_frames / fps
|
||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||
if args.video.endswith(".npy"):
|
||||
all_frames = np.load(args.video) # [T, H, W, C] uint8
|
||||
fps = args.source_fps
|
||||
total_frames = all_frames.shape[0]
|
||||
duration = total_frames / fps
|
||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||
|
||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||
clip_frames = all_frames[clip_indices]
|
||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||
|
||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||
sync_frames = all_frames[sync_indices]
|
||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||
else:
|
||||
from decord import VideoReader, cpu
|
||||
vr = VideoReader(args.video, ctx=cpu(0))
|
||||
fps = vr.get_avg_fps()
|
||||
total_frames = len(vr)
|
||||
duration = total_frames / fps
|
||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||
|
||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||
|
||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||
|
||||
clip_transform = T.Compose([
|
||||
T.ToPILImage(),
|
||||
@@ -98,11 +121,6 @@ def main():
|
||||
])
|
||||
clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device)
|
||||
|
||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||
|
||||
sync_transform = T.Compose([
|
||||
T.ToPILImage(),
|
||||
T.Resize(args.sync_size),
|
||||
|
||||
Reference in New Issue
Block a user