diff --git a/scripts/extract_features.py b/scripts/extract_features.py index 9ed844b..d82aff5 100755 --- a/scripts/extract_features.py +++ b/scripts/extract_features.py @@ -123,11 +123,11 @@ def main(): print(f"[extract] Saving features to {args.output} ...", flush=True) np.savez( args.output, - video_features=video_features.cpu().numpy(), - global_video_features=global_video_features.cpu().numpy(), - text_features=text_features.cpu().numpy(), - global_text_features=global_text_features.cpu().numpy(), - sync_features=sync_features.cpu().numpy(), + video_features=video_features.cpu().float().numpy(), + global_video_features=global_video_features.cpu().float().numpy(), + text_features=text_features.cpu().float().numpy(), + global_text_features=global_text_features.cpu().float().numpy(), + sync_features=sync_features.cpu().float().numpy(), caption_cot=args.cot_text, duration=duration, )