From 7c54ee84820da655d9416a38017fbecc3e42fb36 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 27 Mar 2026 18:06:10 +0100 Subject: [PATCH] feat: PrismAudioFeatureExtractor node with subprocess bridge and conda env Co-Authored-By: Claude Sonnet 4.6 --- nodes/feature_extractor.py | 102 ++++++++++++++++++++++++++++++++ scripts/environment.yml | 21 +++++++ scripts/extract_features.py | 112 ++++++++++++++++++++++++++++++++++++ 3 files changed, 235 insertions(+) create mode 100644 nodes/feature_extractor.py create mode 100644 scripts/environment.yml create mode 100755 scripts/extract_features.py diff --git a/nodes/feature_extractor.py b/nodes/feature_extractor.py new file mode 100644 index 0000000..d593743 --- /dev/null +++ b/nodes/feature_extractor.py @@ -0,0 +1,102 @@ +import os +import hashlib +import subprocess +import tempfile +import torch + +from .utils import PRISMAUDIO_CATEGORY +from .feature_loader import PrismAudioFeatureLoader + + +def _hash_inputs(video_tensor, cot_text): + """Create a hash of the inputs for caching.""" + h = hashlib.sha256() + h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed + h.update(cot_text.encode()) + return h.hexdigest()[:16] + + +def _save_video_tensor_to_mp4(video_tensor, output_path, fps=30): + """Save ComfyUI IMAGE tensor [T,H,W,C] to MP4.""" + import torchvision.io as tvio + # ComfyUI IMAGE is [T,H,W,C] float32 [0,1] + frames = (video_tensor * 255).to(torch.uint8) + # torchvision write_video expects [T,H,W,C] uint8 + tvio.write_video(output_path, frames, fps=fps) + + +class PrismAudioFeatureExtractor: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "video": ("IMAGE",), + "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}), + }, + "optional": { + "python_env": ("STRING", {"default": "python", "tooltip": "Path to python binary with JAX/TF (e.g., /path/to/conda/envs/prismaudio-extract/bin/python)"}), + "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}), + "synchformer_ckpt": ("STRING", {"default": "", "tooltip": "Path to synchformer checkpoint (auto-resolved if empty)"}), + }, + } + + RETURN_TYPES = ("PRISMAUDIO_FEATURES",) + RETURN_NAMES = ("features",) + FUNCTION = "extract_features" + CATEGORY = PRISMAUDIO_CATEGORY + + def extract_features(self, video, caption_cot, python_env="python", cache_dir="", synchformer_ckpt=""): + # Determine cache directory + if not cache_dir: + cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features") + os.makedirs(cache_dir, exist_ok=True) + + # Check cache + cache_hash = _hash_inputs(video, caption_cot) + cached_path = os.path.join(cache_dir, f"{cache_hash}.npz") + if os.path.exists(cached_path): + print(f"[PrismAudio] Using cached features: {cached_path}") + loader = PrismAudioFeatureLoader() + return loader.load_features(cached_path) + + # Save video to temp file + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp: + tmp_video = tmp.name + _save_video_tensor_to_mp4(video, tmp_video) + + # Build subprocess command + script_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "scripts", "extract_features.py" + ) + + cmd = [ + python_env, + script_path, + "--video", tmp_video, + "--cot_text", caption_cot, + "--output", cached_path, + ] + if synchformer_ckpt: + cmd.extend(["--synchformer_ckpt", synchformer_ckpt]) + + print(f"[PrismAudio] Extracting features via subprocess...") + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=600, # 10 minute timeout + ) + if result.returncode != 0: + raise RuntimeError( + f"[PrismAudio] Feature extraction failed:\n{result.stderr}" + ) + print(result.stdout) + finally: + if os.path.exists(tmp_video): + os.unlink(tmp_video) + + # Load the extracted features + loader = PrismAudioFeatureLoader() + return loader.load_features(cached_path) diff --git a/scripts/environment.yml b/scripts/environment.yml new file mode 100644 index 0000000..f098156 --- /dev/null +++ b/scripts/environment.yml @@ -0,0 +1,21 @@ +name: prismaudio-extract +channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - pip + - ffmpeg<7 + - pip: + - torch>=2.6.0 + - torchaudio>=2.6.0 + - torchvision>=0.21.0 + - tensorflow-cpu==2.15.0 + - jax + - jaxlib + - transformers>=4.52.3 + - decord + - einops>=0.7.0 + - numpy + - mediapy + - git+https://github.com/google-deepmind/videoprism.git diff --git a/scripts/extract_features.py b/scripts/extract_features.py new file mode 100755 index 0000000..142697d --- /dev/null +++ b/scripts/extract_features.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +""" +Standalone PrismAudio feature extraction script. +Run in a separate conda env with JAX/TF installed. + +Usage: + python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz + +Setup: + conda env create -f environment.yml + conda activate prismaudio-extract +""" + +import argparse +import os +import sys +import numpy as np +import torch + + +def main(): + parser = argparse.ArgumentParser(description="PrismAudio feature extraction") + parser.add_argument("--video", required=True, help="Path to input video") + parser.add_argument("--cot_text", required=True, help="Chain-of-thought description") + parser.add_argument("--output", required=True, help="Output .npz path") + parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint") + parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON") + parser.add_argument("--clip_fps", type=float, default=4.0) + parser.add_argument("--clip_size", type=int, default=288) + parser.add_argument("--sync_fps", type=float, default=25.0) + parser.add_argument("--sync_size", type=int, default=224) + args = parser.parse_args() + + if not os.path.exists(args.video): + print(f"Error: Video not found: {args.video}") + sys.exit(1) + + # Import feature extraction utils (requires JAX/TF) + from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils + import torchvision.transforms as T + from decord import VideoReader, cpu + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Initialize feature extractor + feat_utils = FeaturesUtils( + vae_config_path=args.vae_config, + synchformer_ckpt=args.synchformer_ckpt, + device=device, + ) + + # Load and preprocess video + vr = VideoReader(args.video, ctx=cpu(0)) + fps = vr.get_avg_fps() + total_frames = len(vr) + duration = total_frames / fps + + # Extract CLIP frames (4fps, 288x288) + clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] + clip_indices = [min(i, total_frames - 1) for i in clip_indices] + clip_frames = vr.get_batch(clip_indices).asnumpy() + + clip_transform = T.Compose([ + T.ToPILImage(), + T.Resize(args.clip_size), + T.CenterCrop(args.clip_size), + T.ToTensor(), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device) + + # Extract Sync frames (25fps, 224x224) + sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] + sync_indices = [min(i, total_frames - 1) for i in sync_indices] + sync_frames = vr.get_batch(sync_indices).asnumpy() + + sync_transform = T.Compose([ + T.ToPILImage(), + T.Resize(args.sync_size), + T.CenterCrop(args.sync_size), + T.ToTensor(), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device) + + # Extract features + print("[PrismAudio] Encoding text with T5-Gemma...") + text_features = feat_utils.encode_t5_text([args.cot_text]) + + print("[PrismAudio] Encoding video with VideoPrism...") + global_video_features, video_features, global_text_features = \ + feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text]) + + print("[PrismAudio] Encoding video with Synchformer...") + sync_features = feat_utils.encode_video_with_sync(sync_input) + + # Save as .npz + np.savez( + args.output, + video_features=video_features.cpu().numpy(), + global_video_features=global_video_features.cpu().numpy(), + text_features=text_features.cpu().numpy(), + global_text_features=global_text_features.cpu().numpy(), + sync_features=sync_features.cpu().numpy(), + caption_cot=args.cot_text, + duration=duration, + ) + print(f"[PrismAudio] Features saved to {args.output}") + + +if __name__ == "__main__": + main()