feat: add per-step timing to feature extraction logs

Each step now prints elapsed seconds on completion.
Total time printed at the end to identify bottlenecks.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 21:13:42 +01:00
parent 63bd999dfa
commit ca87c41a2e
+30 -18
View File
@@ -10,6 +10,7 @@ Usage:
import argparse import argparse
import os import os
import sys import sys
import time
import numpy as np import numpy as np
import torch import torch
@@ -20,7 +21,21 @@ if _PLUGIN_DIR not in sys.path:
sys.path.insert(0, _PLUGIN_DIR) sys.path.insert(0, _PLUGIN_DIR)
def _step(n, total, label):
"""Print step header and return start time."""
print(f"[extract] Step {n}/{total}{label}...", flush=True)
return time.perf_counter()
def _done(t0, extra=""):
elapsed = time.perf_counter() - t0
suffix = f" {extra}" if extra else ""
print(f"[extract] done in {elapsed:.1f}s{suffix}", flush=True)
def main(): def main():
t_total = time.perf_counter()
parser = argparse.ArgumentParser(description="PrismAudio feature extraction") parser = argparse.ArgumentParser(description="PrismAudio feature extraction")
parser.add_argument("--video", required=True, help="Path to input video") parser.add_argument("--video", required=True, help="Path to input video")
parser.add_argument("--cot_text", required=True, help="Chain-of-thought description") parser.add_argument("--cot_text", required=True, help="Chain-of-thought description")
@@ -46,23 +61,23 @@ def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
print("[extract] Step 1/6 — importing dependencies...", flush=True) t0 = _step(1, 6, "importing dependencies")
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
import torchvision.transforms as T import torchvision.transforms as T
from decord import VideoReader, cpu from decord import VideoReader, cpu
print("[extract] Step 1/6 — done", flush=True) _done(t0)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
print("[extract] Step 2/6 — loading models (T5, VideoPrism, Synchformer)...", flush=True) t0 = _step(2, 6, "loading models (T5, VideoPrism, Synchformer)")
feat_utils = FeaturesUtils( feat_utils = FeaturesUtils(
vae_config_path=args.vae_config, vae_config_path=args.vae_config,
synchformer_ckpt=args.synchformer_ckpt, synchformer_ckpt=args.synchformer_ckpt,
device=device, device=device,
) )
print("[extract] Step 2/6 — done", flush=True) _done(t0)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
print("[extract] Step 3/6 — reading and preprocessing video...", flush=True) t0 = _step(3, 6, "reading and preprocessing video")
vr = VideoReader(args.video, ctx=cpu(0)) vr = VideoReader(args.video, ctx=cpu(0))
fps = vr.get_avg_fps() fps = vr.get_avg_fps()
total_frames = len(vr) total_frames = len(vr)
@@ -96,30 +111,26 @@ def main():
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]) ])
sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device) sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device)
print("[extract] Step 3/6 — done", flush=True) _done(t0)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
print("[extract] Step 4/6 — encoding text with T5-Gemma...", flush=True) t0 = _step(4, 6, "encoding text with T5-Gemma")
text_features = feat_utils.encode_t5_text([args.cot_text]) text_features = feat_utils.encode_t5_text([args.cot_text])
print(f"[extract] text_features shape: {tuple(text_features.shape)}", flush=True) _done(t0, f"shape={tuple(text_features.shape)}")
print("[extract] Step 4/6 — done", flush=True)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
print("[extract] Step 5/6 — encoding video with VideoPrism...", flush=True) t0 = _step(5, 6, "encoding video with VideoPrism")
global_video_features, video_features, global_text_features = \ global_video_features, video_features, global_text_features = \
feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text]) feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text])
print(f"[extract] global_video_features : {tuple(global_video_features.shape)}", flush=True) _done(t0, f"video={tuple(video_features.shape)} global={tuple(global_video_features.shape)}")
print(f"[extract] video_features : {tuple(video_features.shape)}", flush=True)
print(f"[extract] global_text_features : {tuple(global_text_features.shape)}", flush=True)
print("[extract] Step 5/6 — done", flush=True)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
print("[extract] Step 6/6 — encoding video with Synchformer...", flush=True) t0 = _step(6, 6, "encoding video with Synchformer")
sync_features = feat_utils.encode_video_with_sync(sync_input) sync_features = feat_utils.encode_video_with_sync(sync_input)
print(f"[extract] sync_features shape: {tuple(sync_features.shape)}", flush=True) _done(t0, f"shape={tuple(sync_features.shape)}")
print("[extract] Step 6/6 — done", flush=True)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
t0 = time.perf_counter()
print(f"[extract] Saving features to {args.output} ...", flush=True) print(f"[extract] Saving features to {args.output} ...", flush=True)
np.savez( np.savez(
args.output, args.output,
@@ -131,7 +142,8 @@ def main():
caption_cot=args.cot_text, caption_cot=args.cot_text,
duration=duration, duration=duration,
) )
print(f"[extract] Done — features saved to {args.output}", flush=True) print(f"[extract] Saved in {time.perf_counter() - t0:.1f}s", flush=True)
print(f"[extract] Total time: {time.perf_counter() - t_total:.1f}s", flush=True)
if __name__ == "__main__": if __name__ == "__main__":