feat: PrismAudioFeatureExtractor node with subprocess bridge and conda env
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,102 @@
|
|||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .utils import PRISMAUDIO_CATEGORY
|
||||||
|
from .feature_loader import PrismAudioFeatureLoader
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_inputs(video_tensor, cot_text):
|
||||||
|
"""Create a hash of the inputs for caching."""
|
||||||
|
h = hashlib.sha256()
|
||||||
|
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed
|
||||||
|
h.update(cot_text.encode())
|
||||||
|
return h.hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
|
def _save_video_tensor_to_mp4(video_tensor, output_path, fps=30):
|
||||||
|
"""Save ComfyUI IMAGE tensor [T,H,W,C] to MP4."""
|
||||||
|
import torchvision.io as tvio
|
||||||
|
# ComfyUI IMAGE is [T,H,W,C] float32 [0,1]
|
||||||
|
frames = (video_tensor * 255).to(torch.uint8)
|
||||||
|
# torchvision write_video expects [T,H,W,C] uint8
|
||||||
|
tvio.write_video(output_path, frames, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
|
class PrismAudioFeatureExtractor:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"video": ("IMAGE",),
|
||||||
|
"caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"python_env": ("STRING", {"default": "python", "tooltip": "Path to python binary with JAX/TF (e.g., /path/to/conda/envs/prismaudio-extract/bin/python)"}),
|
||||||
|
"cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
|
||||||
|
"synchformer_ckpt": ("STRING", {"default": "", "tooltip": "Path to synchformer checkpoint (auto-resolved if empty)"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
|
||||||
|
RETURN_NAMES = ("features",)
|
||||||
|
FUNCTION = "extract_features"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def extract_features(self, video, caption_cot, python_env="python", cache_dir="", synchformer_ckpt=""):
|
||||||
|
# Determine cache directory
|
||||||
|
if not cache_dir:
|
||||||
|
cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features")
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Check cache
|
||||||
|
cache_hash = _hash_inputs(video, caption_cot)
|
||||||
|
cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
|
||||||
|
if os.path.exists(cached_path):
|
||||||
|
print(f"[PrismAudio] Using cached features: {cached_path}")
|
||||||
|
loader = PrismAudioFeatureLoader()
|
||||||
|
return loader.load_features(cached_path)
|
||||||
|
|
||||||
|
# Save video to temp file
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
|
||||||
|
tmp_video = tmp.name
|
||||||
|
_save_video_tensor_to_mp4(video, tmp_video)
|
||||||
|
|
||||||
|
# Build subprocess command
|
||||||
|
script_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(__file__)),
|
||||||
|
"scripts", "extract_features.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
python_env,
|
||||||
|
script_path,
|
||||||
|
"--video", tmp_video,
|
||||||
|
"--cot_text", caption_cot,
|
||||||
|
"--output", cached_path,
|
||||||
|
]
|
||||||
|
if synchformer_ckpt:
|
||||||
|
cmd.extend(["--synchformer_ckpt", synchformer_ckpt])
|
||||||
|
|
||||||
|
print(f"[PrismAudio] Extracting features via subprocess...")
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=600, # 10 minute timeout
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[PrismAudio] Feature extraction failed:\n{result.stderr}"
|
||||||
|
)
|
||||||
|
print(result.stdout)
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp_video):
|
||||||
|
os.unlink(tmp_video)
|
||||||
|
|
||||||
|
# Load the extracted features
|
||||||
|
loader = PrismAudioFeatureLoader()
|
||||||
|
return loader.load_features(cached_path)
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
name: prismaudio-extract
|
||||||
|
channels:
|
||||||
|
- conda-forge
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- python=3.10
|
||||||
|
- pip
|
||||||
|
- ffmpeg<7
|
||||||
|
- pip:
|
||||||
|
- torch>=2.6.0
|
||||||
|
- torchaudio>=2.6.0
|
||||||
|
- torchvision>=0.21.0
|
||||||
|
- tensorflow-cpu==2.15.0
|
||||||
|
- jax
|
||||||
|
- jaxlib
|
||||||
|
- transformers>=4.52.3
|
||||||
|
- decord
|
||||||
|
- einops>=0.7.0
|
||||||
|
- numpy
|
||||||
|
- mediapy
|
||||||
|
- git+https://github.com/google-deepmind/videoprism.git
|
||||||
Executable
+112
@@ -0,0 +1,112 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Standalone PrismAudio feature extraction script.
|
||||||
|
Run in a separate conda env with JAX/TF installed.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz
|
||||||
|
|
||||||
|
Setup:
|
||||||
|
conda env create -f environment.yml
|
||||||
|
conda activate prismaudio-extract
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="PrismAudio feature extraction")
|
||||||
|
parser.add_argument("--video", required=True, help="Path to input video")
|
||||||
|
parser.add_argument("--cot_text", required=True, help="Chain-of-thought description")
|
||||||
|
parser.add_argument("--output", required=True, help="Output .npz path")
|
||||||
|
parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
|
||||||
|
parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
|
||||||
|
parser.add_argument("--clip_fps", type=float, default=4.0)
|
||||||
|
parser.add_argument("--clip_size", type=int, default=288)
|
||||||
|
parser.add_argument("--sync_fps", type=float, default=25.0)
|
||||||
|
parser.add_argument("--sync_size", type=int, default=224)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not os.path.exists(args.video):
|
||||||
|
print(f"Error: Video not found: {args.video}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Import feature extraction utils (requires JAX/TF)
|
||||||
|
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from decord import VideoReader, cpu
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# Initialize feature extractor
|
||||||
|
feat_utils = FeaturesUtils(
|
||||||
|
vae_config_path=args.vae_config,
|
||||||
|
synchformer_ckpt=args.synchformer_ckpt,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load and preprocess video
|
||||||
|
vr = VideoReader(args.video, ctx=cpu(0))
|
||||||
|
fps = vr.get_avg_fps()
|
||||||
|
total_frames = len(vr)
|
||||||
|
duration = total_frames / fps
|
||||||
|
|
||||||
|
# Extract CLIP frames (4fps, 288x288)
|
||||||
|
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
||||||
|
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||||
|
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
||||||
|
|
||||||
|
clip_transform = T.Compose([
|
||||||
|
T.ToPILImage(),
|
||||||
|
T.Resize(args.clip_size),
|
||||||
|
T.CenterCrop(args.clip_size),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
# Extract Sync frames (25fps, 224x224)
|
||||||
|
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
||||||
|
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||||
|
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
||||||
|
|
||||||
|
sync_transform = T.Compose([
|
||||||
|
T.ToPILImage(),
|
||||||
|
T.Resize(args.sync_size),
|
||||||
|
T.CenterCrop(args.sync_size),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
# Extract features
|
||||||
|
print("[PrismAudio] Encoding text with T5-Gemma...")
|
||||||
|
text_features = feat_utils.encode_t5_text([args.cot_text])
|
||||||
|
|
||||||
|
print("[PrismAudio] Encoding video with VideoPrism...")
|
||||||
|
global_video_features, video_features, global_text_features = \
|
||||||
|
feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text])
|
||||||
|
|
||||||
|
print("[PrismAudio] Encoding video with Synchformer...")
|
||||||
|
sync_features = feat_utils.encode_video_with_sync(sync_input)
|
||||||
|
|
||||||
|
# Save as .npz
|
||||||
|
np.savez(
|
||||||
|
args.output,
|
||||||
|
video_features=video_features.cpu().numpy(),
|
||||||
|
global_video_features=global_video_features.cpu().numpy(),
|
||||||
|
text_features=text_features.cpu().numpy(),
|
||||||
|
global_text_features=global_text_features.cpu().numpy(),
|
||||||
|
sync_features=sync_features.cpu().numpy(),
|
||||||
|
caption_cot=args.cot_text,
|
||||||
|
duration=duration,
|
||||||
|
)
|
||||||
|
print(f"[PrismAudio] Features saved to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user