diff --git a/nodes/feature_extractor.py b/nodes/feature_extractor.py index 6254766..db306f7 100644 --- a/nodes/feature_extractor.py +++ b/nodes/feature_extractor.py @@ -171,19 +171,20 @@ class PrismAudioFeatureExtractor: if synchformer_ckpt: cmd.extend(["--synchformer_ckpt", synchformer_ckpt]) - print(f"[PrismAudio] Extracting features via subprocess...") + print(f"[PrismAudio] Extracting features via subprocess (output streams live)...") try: + # capture_output=False: let stdout/stderr stream directly to ComfyUI logs result = subprocess.run( cmd, - capture_output=True, - text=True, + capture_output=False, timeout=600, # 10 minute timeout ) if result.returncode != 0: raise RuntimeError( - f"[PrismAudio] Feature extraction failed:\n{result.stderr}" + f"[PrismAudio] Feature extraction subprocess exited with code {result.returncode}. " + "See output above for details." ) - print(result.stdout) + print("[PrismAudio] Feature extraction subprocess finished successfully.") finally: if os.path.exists(tmp_video): os.unlink(tmp_video) diff --git a/scripts/extract_features.py b/scripts/extract_features.py index c4f0ea5..9ed844b 100755 --- a/scripts/extract_features.py +++ b/scripts/extract_features.py @@ -33,34 +33,46 @@ def main(): parser.add_argument("--sync_size", type=int, default=224) args = parser.parse_args() + print(f"[extract] Python : {sys.executable}", flush=True) + print(f"[extract] Video : {args.video}", flush=True) + print(f"[extract] Output : {args.output}", flush=True) + print(f"[extract] CoT text : {args.cot_text[:80]}{'...' if len(args.cot_text) > 80 else ''}", flush=True) + if not os.path.exists(args.video): - print(f"Error: Video not found: {args.video}") + print(f"[extract] ERROR: video not found: {args.video}", flush=True) sys.exit(1) - # Import feature extraction utils (requires JAX/TF) + print(f"[extract] Device : {'cuda' if torch.cuda.is_available() else 'cpu'}", flush=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # ------------------------------------------------------------------ + print("[extract] Step 1/6 — importing dependencies...", flush=True) from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils import torchvision.transforms as T from decord import VideoReader, cpu + print("[extract] Step 1/6 — done", flush=True) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Initialize feature extractor + # ------------------------------------------------------------------ + print("[extract] Step 2/6 — loading models (T5, VideoPrism, Synchformer)...", flush=True) feat_utils = FeaturesUtils( vae_config_path=args.vae_config, synchformer_ckpt=args.synchformer_ckpt, device=device, ) + print("[extract] Step 2/6 — done", flush=True) - # Load and preprocess video + # ------------------------------------------------------------------ + print("[extract] Step 3/6 — reading and preprocessing video...", flush=True) vr = VideoReader(args.video, ctx=cpu(0)) fps = vr.get_avg_fps() total_frames = len(vr) duration = total_frames / fps + print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True) - # Extract CLIP frames (4fps, 288x288) clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] clip_indices = [min(i, total_frames - 1) for i in clip_indices] clip_frames = vr.get_batch(clip_indices).asnumpy() + print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True) clip_transform = T.Compose([ T.ToPILImage(), @@ -71,10 +83,10 @@ def main(): ]) clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device) - # Extract Sync frames (25fps, 224x224) sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] sync_indices = [min(i, total_frames - 1) for i in sync_indices] sync_frames = vr.get_batch(sync_indices).asnumpy() + print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True) sync_transform = T.Compose([ T.ToPILImage(), @@ -84,19 +96,31 @@ def main(): T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ]) sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device) + print("[extract] Step 3/6 — done", flush=True) - # Extract features - print("[PrismAudio] Encoding text with T5-Gemma...") + # ------------------------------------------------------------------ + print("[extract] Step 4/6 — encoding text with T5-Gemma...", flush=True) text_features = feat_utils.encode_t5_text([args.cot_text]) + print(f"[extract] text_features shape: {tuple(text_features.shape)}", flush=True) + print("[extract] Step 4/6 — done", flush=True) - print("[PrismAudio] Encoding video with VideoPrism...") + # ------------------------------------------------------------------ + print("[extract] Step 5/6 — encoding video with VideoPrism...", flush=True) global_video_features, video_features, global_text_features = \ feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text]) + print(f"[extract] global_video_features : {tuple(global_video_features.shape)}", flush=True) + print(f"[extract] video_features : {tuple(video_features.shape)}", flush=True) + print(f"[extract] global_text_features : {tuple(global_text_features.shape)}", flush=True) + print("[extract] Step 5/6 — done", flush=True) - print("[PrismAudio] Encoding video with Synchformer...") + # ------------------------------------------------------------------ + print("[extract] Step 6/6 — encoding video with Synchformer...", flush=True) sync_features = feat_utils.encode_video_with_sync(sync_input) + print(f"[extract] sync_features shape: {tuple(sync_features.shape)}", flush=True) + print("[extract] Step 6/6 — done", flush=True) - # Save as .npz + # ------------------------------------------------------------------ + print(f"[extract] Saving features to {args.output} ...", flush=True) np.savez( args.output, video_features=video_features.cpu().numpy(), @@ -107,7 +131,7 @@ def main(): caption_cot=args.cot_text, duration=duration, ) - print(f"[PrismAudio] Features saved to {args.output}") + print(f"[extract] Done — features saved to {args.output}", flush=True) if __name__ == "__main__":