diff --git a/nodes/feature_extractor.py b/nodes/feature_extractor.py index 4b0117e..994f31c 100644 --- a/nodes/feature_extractor.py +++ b/nodes/feature_extractor.py @@ -78,29 +78,14 @@ def _hash_inputs(video_tensor, cot_text): return h.hexdigest()[:16] -def _save_video_tensor_to_mp4(video_tensor, output_path, fps=30): - """Save ComfyUI IMAGE tensor [T,H,W,C] to MP4 by piping raw RGB to ffmpeg. +def _save_frames_to_npy(video_tensor, output_path): + """Save ComfyUI IMAGE tensor [T,H,W,C] float32 [0,1] to .npy as uint8. - Avoids intermediate PNG files — frames are streamed directly to ffmpeg stdin. + Lossless — avoids H.264 encode/decode roundtrip. """ + import numpy as np frames_np = (video_tensor.cpu().numpy() * 255).astype("uint8") - T, H, W, C = frames_np.shape - - result = subprocess.run( - [ - "ffmpeg", "-y", - "-f", "rawvideo", "-vcodec", "rawvideo", - "-s", f"{W}x{H}", "-pix_fmt", "rgb24", - "-r", str(fps), - "-i", "pipe:0", - "-c:v", "libx264", "-pix_fmt", "yuv420p", - output_path, - ], - input=frames_np.tobytes(), - capture_output=True, - ) - if result.returncode != 0: - raise RuntimeError(f"[PrismAudio] ffmpeg failed:\n{result.stderr.decode()}") + np.save(output_path, frames_np) class PrismAudioFeatureExtractor: @@ -143,15 +128,15 @@ class PrismAudioFeatureExtractor: loader = PrismAudioFeatureLoader() return loader.load_features(cached_path) - # Save video to temp file + # Save frames to temp file (lossless .npy, no codec roundtrip) import time t0 = time.perf_counter() frames = video.shape[0] - print(f"[PrismAudio] Converting {frames} frames to MP4 (fps={fps})...", flush=True) - with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp: + print(f"[PrismAudio] Saving {frames} frames to .npy (fps={fps})...", flush=True) + with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp: tmp_video = tmp.name - _save_video_tensor_to_mp4(video, tmp_video, fps=fps) - print(f"[PrismAudio] MP4 ready in {time.perf_counter() - t0:.1f}s ({tmp_video})", flush=True) + _save_frames_to_npy(video, tmp_video) + print(f"[PrismAudio] Frames saved in {time.perf_counter() - t0:.1f}s", flush=True) # Build subprocess command script_path = os.path.join( @@ -165,6 +150,7 @@ class PrismAudioFeatureExtractor: "--video", tmp_video, "--cot_text", caption_cot, "--output", cached_path, + "--source_fps", str(fps), ] # Auto-resolve synchformer checkpoint from the prismaudio models dir if not synchformer_ckpt: diff --git a/scripts/extract_features.py b/scripts/extract_features.py index 6098710..00303a5 100755 --- a/scripts/extract_features.py +++ b/scripts/extract_features.py @@ -42,6 +42,7 @@ def main(): parser.add_argument("--output", required=True, help="Output .npz path") parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint") parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON") + parser.add_argument("--source_fps", type=float, default=30.0, help="Original video fps (used when --video is a .npy file)") parser.add_argument("--clip_fps", type=float, default=4.0) parser.add_argument("--clip_size", type=int, default=288) parser.add_argument("--sync_fps", type=float, default=25.0) @@ -64,7 +65,6 @@ def main(): t0 = _step(1, 6, "importing dependencies") from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils import torchvision.transforms as T - from decord import VideoReader, cpu _done(t0) # ------------------------------------------------------------------ @@ -78,16 +78,39 @@ def main(): # ------------------------------------------------------------------ t0 = _step(3, 6, "reading and preprocessing video") - vr = VideoReader(args.video, ctx=cpu(0)) - fps = vr.get_avg_fps() - total_frames = len(vr) - duration = total_frames / fps - print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True) + if args.video.endswith(".npy"): + all_frames = np.load(args.video) # [T, H, W, C] uint8 + fps = args.source_fps + total_frames = all_frames.shape[0] + duration = total_frames / fps + print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True) - clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] - clip_indices = [min(i, total_frames - 1) for i in clip_indices] - clip_frames = vr.get_batch(clip_indices).asnumpy() - print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True) + clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] + clip_indices = [min(i, total_frames - 1) for i in clip_indices] + clip_frames = all_frames[clip_indices] + print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True) + + sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] + sync_indices = [min(i, total_frames - 1) for i in sync_indices] + sync_frames = all_frames[sync_indices] + print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True) + else: + from decord import VideoReader, cpu + vr = VideoReader(args.video, ctx=cpu(0)) + fps = vr.get_avg_fps() + total_frames = len(vr) + duration = total_frames / fps + print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True) + + clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] + clip_indices = [min(i, total_frames - 1) for i in clip_indices] + clip_frames = vr.get_batch(clip_indices).asnumpy() + print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True) + + sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] + sync_indices = [min(i, total_frames - 1) for i in sync_indices] + sync_frames = vr.get_batch(sync_indices).asnumpy() + print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True) clip_transform = T.Compose([ T.ToPILImage(), @@ -98,11 +121,6 @@ def main(): ]) clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device) - sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] - sync_indices = [min(i, total_frames - 1) for i in sync_indices] - sync_frames = vr.get_batch(sync_indices).asnumpy() - print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True) - sync_transform = T.Compose([ T.ToPILImage(), T.Resize(args.sync_size),