Files
ComfyUI-SelVA/nodes/feature_extractor.py
T
Ethanfel 06f8dbbab4 feat: add hf_token input and HF_TOKEN env forwarding to feature extractor
google/t5gemma-l-l-ul2-it is a gated HuggingFace model requiring auth.
Add optional hf_token input on the node; forward it (plus the legacy
HUGGING_FACE_HUB_TOKEN alias) to the subprocess env. Falls back to
HF_TOKEN from the host environment. Warn clearly when neither is set.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-27 20:27:33 +01:00

208 lines
8.2 KiB
Python

import os
import sys
import hashlib
import subprocess
import tempfile
import torch
from .utils import PRISMAUDIO_CATEGORY
from .feature_loader import PrismAudioFeatureLoader
# Managed venv created automatically when python_env is left as default.
_PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__))
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
# `python -m venv` places executables in bin/ on POSIX but in Scripts\ (with
# an .exe suffix) on Windows. Using the POSIX path everywhere made the managed
# interpreter look permanently missing on Windows, so the venv was wiped and
# rebuilt on every run.
if os.name == "nt":
    _MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "Scripts", "python.exe")
else:
    _MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")
# Requirement specifiers installed into the managed venv, one pip call each.
_EXTRACT_PACKAGES = [
    "torch", "torchaudio", "torchvision",
    # TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
    "tensorflow-cpu>=2.16.0",
    "jax[cpu]", "jaxlib",
    "transformers", "decord", "einops", "numpy", "mediapy",
    "git+https://github.com/google-deepmind/videoprism.git",
]
def _pip_install(pip, *packages, label=None):
    """Run ``pip install`` for *packages*, streaming pip's own output.

    Args:
        pip: path of the pip executable to invoke.
        *packages: requirement specifiers handed to pip unchanged.
        label: name shown in progress messages; when falsy the first
            package specifier is used instead.

    Raises:
        RuntimeError: pip exited with a non-zero status.
    """
    name = label if label else packages[0]
    print(f"[PrismAudio] installing {name} ...", flush=True)

    argv = [pip, "install", "--progress-bar", "on", *packages]
    # Output is not captured so pip's progress bar shows up in the console.
    exit_code = subprocess.run(argv, capture_output=False).returncode

    if exit_code == 0:
        print(f"[PrismAudio] {name} OK", flush=True)
        return
    raise RuntimeError(
        f"[PrismAudio] Failed to install {name} (exit {exit_code}). "
        "See pip output above for details."
    )
def _ensure_extract_env():
    """Return the managed venv's interpreter, creating and populating it first if needed.

    A marker file inside the venv exists for the whole duration of a build
    and is deleted only after every package in ``_EXTRACT_PACKAGES`` installed
    successfully. If a previous build crashed or a package failed to install,
    the marker is still present and the venv is wiped and rebuilt instead of
    being returned half-populated. (Checking only for the interpreter is not
    enough: ``venv`` creates it before any package is installed.) Venvs built
    by earlier versions carry no marker and are accepted as-is.

    Returns:
        Path of the managed venv's python interpreter (``_MANAGED_PYTHON``).

    Raises:
        subprocess.CalledProcessError: venv creation or the pip upgrade failed.
        RuntimeError: installing one of ``_EXTRACT_PACKAGES`` failed.
    """
    import shutil

    marker = os.path.join(_MANAGED_VENV, ".prismaudio_install_incomplete")
    if os.path.exists(_MANAGED_PYTHON) and not os.path.exists(marker):
        return _MANAGED_PYTHON

    if os.path.exists(_MANAGED_VENV):
        print("[PrismAudio] Removing incomplete venv and retrying...", flush=True)
        shutil.rmtree(_MANAGED_VENV)

    print(f"[PrismAudio] Creating feature-extraction venv at: {_MANAGED_VENV}", flush=True)
    # Write the marker *before* running venv so a failure at any later point,
    # including inside venv/ensurepip itself, is detected on the next call.
    os.makedirs(_MANAGED_VENV)
    with open(marker, "w", encoding="utf-8") as fh:
        fh.write("PrismAudio feature-extraction env install in progress\n")
    subprocess.run([sys.executable, "-m", "venv", _MANAGED_VENV], check=True)

    # Locate pip next to the interpreter so it follows the platform's venv
    # layout (bin/ on POSIX, Scripts\ on Windows).
    bin_dir = os.path.dirname(_MANAGED_PYTHON)
    pip = os.path.join(bin_dir, "pip.exe" if os.name == "nt" else "pip")

    print("[PrismAudio] Upgrading pip...", flush=True)
    # `python -m pip` is required for self-upgrade on Windows, where a running
    # pip executable cannot overwrite itself; it works identically on POSIX.
    subprocess.run([_MANAGED_PYTHON, "-m", "pip", "install", "--upgrade", "pip"], check=True)

    total = len(_EXTRACT_PACKAGES)
    print(f"[PrismAudio] Installing {total} package groups — this may take several minutes...", flush=True)
    for i, pkg in enumerate(_EXTRACT_PACKAGES, 1):
        if pkg.startswith("git+"):
            # e.g. "git+https://.../videoprism.git" -> "videoprism.git"
            label = pkg.split("/")[-1]
        else:
            # Strip version pins and extras: "jax[cpu]" -> "jax".
            label = pkg.split(">=")[0].split("==")[0].split("[")[0]
        print(f"[PrismAudio] [{i}/{total}] {label}", flush=True)
        _pip_install(pip, pkg, label=label)

    # Only now is the venv known to be complete.
    os.remove(marker)
    print("[PrismAudio] Feature-extraction env ready.", flush=True)
    return _MANAGED_PYTHON
def _hash_inputs(video_tensor, cot_text, *extra):
    """Return a 16-hex-character cache key for a video tensor and caption.

    The whole tensor is hashed together with its shape and dtype. Hashing only
    a prefix (as earlier versions did) made videos that share their leading
    bytes collide, e.g. a clip and a longer cut of the same footage, or the
    same bytes with a different shape. The colliding video then silently
    reused another video's cached features.

    Args:
        video_tensor: torch tensor holding the video frames (any device).
        cot_text: caption / chain-of-thought text.
        *extra: optional further values that also influence the result
            (e.g. fps, checkpoint path); each is hashed through ``repr``.

    Returns:
        The first 16 hex characters of the SHA-256 digest.
    """
    frames = video_tensor.detach().cpu().contiguous().numpy()
    h = hashlib.sha256()
    h.update(f"{tuple(frames.shape)}|{frames.dtype}|".encode("utf-8"))
    # Feed the array buffer directly (a flat byte view), avoiding the full
    # bytes copy that `.tobytes()` makes.
    h.update(frames.reshape(-1).view("uint8"))
    # NUL separators keep component boundaries unambiguous.
    h.update(b"\0")
    h.update((cot_text or "").encode("utf-8"))
    for item in extra:
        h.update(b"\0")
        h.update(repr(item).encode("utf-8"))
    return h.hexdigest()[:16]
def _save_video_tensor_to_mp4(video_tensor, output_path, fps=30):
    """Save a ComfyUI IMAGE tensor [T,H,W,C] to MP4 via PIL + ffmpeg.

    torchvision.io.write_video requires the optional 'av' (PyAV) package
    which is not installed in most ComfyUI environments. ffmpeg is always
    available in ComfyUI Docker images.

    Frames are written as PNGs into a scratch directory next to
    *output_path*. They are then encoded with libx264/yuv420p at *fps*, and
    the scratch directory is always removed afterwards.

    Args:
        video_tensor: tensor of frames; values are expected in [0, 1]
            (presumably float RGB, as ComfyUI IMAGE tensors are).
        output_path: destination .mp4 path (overwritten if present).
        fps: frame rate written into the MP4.

    Raises:
        RuntimeError: the tensor has no frames, ffmpeg is not on PATH, or
            ffmpeg exits with an error.
    """
    from PIL import Image
    import shutil

    # Clip before the uint8 cast: values outside [0, 1] would otherwise wrap
    # around (e.g. 1.01 -> 257 -> 1). This matches ComfyUI's own image savers.
    frames_np = (video_tensor.detach().cpu().numpy() * 255).clip(0, 255).astype("uint8")
    if frames_np.shape[0] == 0:
        raise RuntimeError("[PrismAudio] Cannot encode an empty video (0 frames).")

    frame_dir = output_path + "_frames"
    os.makedirs(frame_dir, exist_ok=True)
    try:
        for i, frame in enumerate(frames_np):
            # Lossless either way; low compression is much faster to write.
            Image.fromarray(frame).save(os.path.join(frame_dir, f"{i:06d}.png"), compress_level=1)
        cmd = [
            "ffmpeg", "-y",
            "-framerate", str(fps),
            "-i", os.path.join(frame_dir, "%06d.png"),
            # libx264 + yuv420p rejects odd widths/heights; pad by at most one
            # pixel on the right/bottom to the next even size.
            "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
            "-c:v", "libx264", "-pix_fmt", "yuv420p",
            output_path,
        ]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True)
        except FileNotFoundError as err:
            raise RuntimeError(
                "[PrismAudio] ffmpeg not found on PATH; it is required to encode the input video."
            ) from err
        if result.returncode != 0:
            raise RuntimeError(f"[PrismAudio] ffmpeg failed:\n{result.stderr}")
    finally:
        shutil.rmtree(frame_dir, ignore_errors=True)
class PrismAudioFeatureExtractor:
    """ComfyUI node that extracts PrismAudio features from a video and caption.

    Extraction runs out-of-process. The frames are encoded to a temporary MP4,
    and ``scripts/extract_features.py`` is executed with ``python_env`` (or an
    auto-installed managed venv). Results are cached as ``.npz`` files in
    ``cache_dir`` and loaded back with ``PrismAudioFeatureLoader``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": ("IMAGE",),
                "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
            },
            "optional": {
                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001, "tooltip": "Frame rate of the input video. Match this to your source (e.g. 24, 25, 30, 60). Affects temporal sampling in feature extraction."}),
                "python_env": ("STRING", {"default": "python", "tooltip": "Path to python binary with JAX/TF. Leave as 'python' to auto-install a managed venv on first use."}),
                "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
                "synchformer_ckpt": ("STRING", {"default": "", "tooltip": "Path to synchformer checkpoint (auto-resolved if empty)"}),
                "hf_token": ("STRING", {"default": "", "tooltip": "HuggingFace token for gated models (e.g. google/t5gemma). Get yours at huggingface.co/settings/tokens"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
    RETURN_NAMES = ("features",)
    FUNCTION = "extract_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    @staticmethod
    def _build_subprocess_env(hf_token):
        """Return a copy of the current environment with the HF token injected.

        Token precedence: the node's ``hf_token`` input, then ``HF_TOKEN``,
        then the legacy ``HUGGING_FACE_HUB_TOKEN`` from the host environment
        (whitespace-only values count as unset). The chosen token is exported
        under both variable names. A warning is printed when none is found.
        """
        # os.environ.copy() returns an independent dict. copy.copy(os.environ)
        # can yield an _Environ sharing the live mapping, so writes to it could
        # leak the token into this (ComfyUI) process's own environment.
        env = os.environ.copy()
        token = (
            (hf_token or "").strip()
            or os.environ.get("HF_TOKEN", "").strip()
            or os.environ.get("HUGGING_FACE_HUB_TOKEN", "").strip()
        )
        if token:
            env["HF_TOKEN"] = token
            env["HUGGING_FACE_HUB_TOKEN"] = token
        else:
            print("[PrismAudio] Warning: no HF_TOKEN set — gated models (e.g. t5gemma) will fail. "
                  "Add your token in the hf_token input or set HF_TOKEN env var.", flush=True)
        return env

    def extract_features(self, video, caption_cot, fps=30.0, python_env="python", cache_dir="", synchformer_ckpt="", hf_token=""):
        """Extract (or load cached) features for *video* and *caption_cot*.

        Returns whatever ``PrismAudioFeatureLoader.load_features`` returns for
        the resulting ``.npz`` file.

        Raises:
            RuntimeError: video encoding failed, or the extraction subprocess
                timed out, exited non-zero, or produced no output file.
        """
        # Resolve python binary — auto-install managed venv if empty or default.
        python_env = (python_env or "").strip()
        if not python_env or python_env == "python":
            python_env = _ensure_extract_env()

        cache_dir = (cache_dir or "").strip() or os.path.join(tempfile.gettempdir(), "prismaudio_features")
        os.makedirs(cache_dir, exist_ok=True)
        synchformer_ckpt = (synchformer_ckpt or "").strip()

        # Cache key: content hash of (video, caption) plus every other input that
        # changes the output. fps is the frame rate written into the MP4 the
        # extractor reads, and the checkpoint selects the model. Leaving either
        # out of the key would return stale features after the user changed it.
        content_hash = _hash_inputs(video, caption_cot)
        key_material = f"{content_hash}|fps={float(fps)!r}|synchformer_ckpt={synchformer_ckpt}"
        cache_hash = hashlib.sha256(key_material.encode("utf-8")).hexdigest()[:16]
        cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
        if os.path.exists(cached_path):
            print(f"[PrismAudio] Using cached features: {cached_path}", flush=True)
            return PrismAudioFeatureLoader().load_features(cached_path)

        # The subprocess writes to a per-process partial file. The file is moved
        # into place only after a successful run, so a crash or timeout never
        # leaves a truncated cache entry that later runs would reuse. The name
        # keeps the ".npz" suffix in case the writer appends one otherwise.
        partial_path = os.path.join(cache_dir, f"{cache_hash}.{os.getpid()}.partial.npz")
        script_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "scripts", "extract_features.py"
        )
        env = self._build_subprocess_env(hf_token)

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_video = tmp.name
        try:
            # Inside the try so the temp file is removed even if encoding fails.
            _save_video_tensor_to_mp4(video, tmp_video, fps=fps)

            cmd = [
                python_env,
                script_path,
                "--video", tmp_video,
                "--cot_text", caption_cot,
                "--output", partial_path,
            ]
            if synchformer_ckpt:
                cmd.extend(["--synchformer_ckpt", synchformer_ckpt])

            print("[PrismAudio] Extracting features via subprocess (output streams live)...", flush=True)
            try:
                # capture_output=False: let stdout/stderr stream directly to ComfyUI logs
                result = subprocess.run(
                    cmd,
                    capture_output=False,
                    timeout=600,  # 10 minute timeout
                    env=env,
                )
            except subprocess.TimeoutExpired as err:
                raise RuntimeError(
                    f"[PrismAudio] Feature extraction subprocess timed out after {err.timeout} s. "
                    "See output above for details."
                ) from err
            if result.returncode != 0:
                raise RuntimeError(
                    f"[PrismAudio] Feature extraction subprocess exited with code {result.returncode}. "
                    "See output above for details."
                )
            if not os.path.exists(partial_path):
                raise RuntimeError(
                    f"[PrismAudio] Feature extraction subprocess succeeded but wrote no output to {partial_path}."
                )
            os.replace(partial_path, cached_path)
            print("[PrismAudio] Feature extraction subprocess finished successfully.", flush=True)
        finally:
            for leftover in (tmp_video, partial_path):
                if os.path.exists(leftover):
                    os.unlink(leftover)

        # Load the extracted features
        return PrismAudioFeatureLoader().load_features(cached_path)