ComfyUI-PrismAudio Implementation Plan
For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
Goal: Build ComfyUI custom nodes for PrismAudio video-to-audio and text-to-audio generation with adaptive VRAM management and isolated feature extraction.
Architecture: Selective code extraction from PrismAudio prismaudio branch into prismaudio_core/ module. 5 ComfyUI nodes (ModelLoader, FeatureLoader, FeatureExtractor, Sampler, TextOnly). Feature extraction via subprocess bridge to isolated JAX/TF environment. Auto-download from HuggingFace with gated model support.
Tech Stack: PyTorch, ComfyUI APIs (folder_paths, comfy.model_management, comfy.utils), HuggingFace Hub, transformers (T5-Gemma), einops, k-diffusion, safetensors
Bug Fixes Applied (from review)
This plan incorporates fixes for all 14 bugs identified during review:
- sample_discrete_euler callback: Copy function into prismaudio_core, add callback param to the sampling loop
- Metadata format: Return (dict,) tuple, not flat dict; matches MultiConditioner.forward(batch_metadata: List[Dict])
- video_exist: Use torch.tensor(True/False), not Python bool
- None features: Use zero tensors of correct shape, never None; pad_sequence(None) crashes
- update_seq_lengths removed: Does not exist in source. Model adapts to input shapes dynamically, so no seq length config is needed
- Sequence length config: Not needed; model handles variable lengths natively via input tensor shapes
- T5-Gemma class: Use AutoModelForSeq2SeqLM.get_encoder(), not AutoModel.encoder
- Peak normalization: Add audio / audio.abs().max().clamp(min=1e-8) before clamp
- Empty feature substitution: Match reference approach; substitute on raw conditioning output with correct shapes
- hf_token security: Remove STRING widget entirely. Rely on env var / huggingface-cli login only. Document in README
- Synchformer size: Corrected to ~950MB in docs
- T5 truncation: Match reference (truncation=False, no max_length)
- Remove global_video/text_features from metadata: Not consumed by any conditioner
- Add tqdm to requirements
Bug Fixes Applied (from second review)
- Sync_MLP zero-tensor crash: Sync zero-tensor fallback must be [8, 768] not [1, 768]; Sync_MLP does length // 8 which gives 0 for length=1, causing F.interpolate on empty tensor
- sample_discrete_euler undefined i: Loop needs enumerate(): for i, (t_curr, t_prev) in enumerate(zip(...))
- _update_seq_lengths removed entirely: Was a no-op (attributes don't exist on DiT). Model handles variable lengths natively, so the function was deleted
- cot_description removed from Sampler: Was dead code; features already contain pre-computed text_features
- Conditioner VRAM leak: Add diffusion.conditioner.to(get_offload_device()) after generation in offload path
- VAE size corrected: ~2.52GB, not ~300MB
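The Sync_MLP crash in this list comes down to integer division. A minimal arithmetic sketch, assuming the conditioner groups frames into segments of 8 as described; the helper name num_segments is hypothetical and not part of the PrismAudio source:

```python
SEGMENT_SIZE = 8  # Sync_MLP groups sync frames into segments of 8 (per this plan)


def num_segments(length: int) -> int:
    """Number of whole 8-frame segments in a sync feature sequence."""
    return length // SEGMENT_SIZE


# A [1, 768] fallback yields zero segments, so later ops see an empty tensor.
# An [8, 768] fallback yields exactly one segment.
for length in (1, 8, 16):
    print(f"length={length} -> segments={num_segments(length)}")
```

This is why the zero-tensor fallback for sync_features uses a leading dimension of 8 rather than 1.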
Bug Fixes Applied (from third review)
- Remove video_features substitution: _substitute_empty_features should only substitute sync_features. Reference code checks for metaclip_features (wrong key for prismaudio config), so video substitution never runs. Cond_MLP with zero input + bias-free linears naturally produces near-zero output
- Remove dead sample() and sample_rf: Wrong noise schedule (linear vs cosine), never called for rectified_flow. Only keep sample_discrete_euler
- VAE decode in fp32: Keep pretransform in fp32 even when rest of model is fp16/bf16, because snake activations overflow in fp16
- Lazy imports in nodes/__init__.py: Use try/except to allow incremental development
- MPS Generator guard: torch.Generator(device="cpu") on Apple Silicon, move noise to device after
- Use comfy.utils.load_torch_file for VAE: Consistent with diffusion loading, handles PyTorch 2.6+ weights_only default
- Task 10 stale reference: Remove mention of _update_seq_lengths
Bug Fixes Applied (from fourth review)
- TextOnly missing MPS guard: Fix-on-fix regression — MPS Generator guard was applied to Sampler but not TextOnly
- TextOnly noise dtype: Was passing dtype to torch.randn directly (fp16 noise), now generates fp32 then converts (matching Sampler)
- Sync substitution seq length: Low-severity divergence from reference, accepted (DiT handles variable-length sync_cond)
Task 1: Project Scaffolding
Files:
- Create: __init__.py
- Create: nodes/__init__.py
- Create: nodes/utils.py
- Create: requirements.txt
Step 1: Create requirements.txt
einops>=0.7.0
safetensors
huggingface_hub
transformers>=4.52.3
k-diffusion>=0.1.1
alias-free-torch
descript-audio-codec
tqdm
Step 2: Create nodes/utils.py with shared helpers
import os
import torch
import folder_paths
import comfy.model_management as mm

PRISMAUDIO_CATEGORY = "PrismAudio"
SAMPLE_RATE = 44100
DOWNSAMPLING_RATIO = 2048
IO_CHANNELS = 64

def get_prismaudio_model_dir():
    """Get or create the prismaudio model directory."""
    model_dir = os.path.join(folder_paths.models_dir, "prismaudio")
    os.makedirs(model_dir, exist_ok=True)
    return model_dir

def register_model_folder():
    """Register prismaudio model folder with ComfyUI."""
    model_dir = get_prismaudio_model_dir()
    folder_paths.add_model_folder_path("prismaudio", model_dir)

def get_device():
    return mm.get_torch_device()

def get_offload_device():
    return mm.unet_offload_device()

def get_free_memory(device=None):
    if device is None:
        device = get_device()
    return mm.get_free_memory(device)

def soft_empty_cache():
    mm.soft_empty_cache()

def determine_precision(preference, device):
    """Determine the best precision for the given device."""
    if preference != "auto":
        return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference]
    if device.type == "cpu":
        return torch.float32
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

def determine_offload_strategy(preference):
    """Determine offload strategy based on available VRAM."""
    if preference != "auto":
        return preference
    free_mem = get_free_memory()
    gb = free_mem / (1024 ** 3)
    if gb >= 24:
        return "keep_in_vram"
    else:
        return "offload_to_cpu"

def try_import_flash_attn():
    """Try to import flash attention, return None if unavailable."""
    try:
        import flash_attn
        return flash_attn
    except ImportError:
        return None

def resolve_hf_token():
    """Resolve HF token from env var or cached login. No widget: security risk."""
    env_token = os.environ.get("HF_TOKEN")
    if env_token:
        return env_token
    # huggingface_hub will use cached token automatically if None is passed
    return None
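The auto offload decision above is a single threshold on free memory. A standalone sketch of that rule, with the free-memory value passed in so it can be checked without a GPU; the function name pick_offload_strategy and the threshold_gib parameter are hypothetical:

```python
GIB = 1024 ** 3


def pick_offload_strategy(free_bytes: int, threshold_gib: float = 24.0) -> str:
    """Mirror the 'auto' rule: keep the model resident only with >= 24 GiB free."""
    if free_bytes / GIB >= threshold_gib:
        return "keep_in_vram"
    return "offload_to_cpu"


# A card reporting 16 GiB free offloads; one reporting 24 GiB free does not.
print(pick_offload_strategy(16 * GIB))
print(pick_offload_strategy(24 * GIB))
```

Note that the comparison uses free memory at load time, not total VRAM, so a 24 GB card with other models resident would still select offloading.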
Step 3: Create nodes/__init__.py
import importlib

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

# Lazy imports: allows incremental development (nodes can be added one at a time)
_NODES = {
    "PrismAudioModelLoader": (".model_loader", "PrismAudioModelLoader", "PrismAudio Model Loader"),
    "PrismAudioFeatureLoader": (".feature_loader", "PrismAudioFeatureLoader", "PrismAudio Feature Loader"),
    "PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
    "PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
    "PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
}

for key, (module_path, class_name, display_name) in _NODES.items():
    try:
        mod = importlib.import_module(module_path, package=__name__)
        NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
        NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
    except (ImportError, AttributeError) as e:
        print(f"[PrismAudio] Skipping {key}: {e}")
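The registration loop above tolerates modules that do not exist yet. A self-contained sketch of the same pattern, using standard-library modules as stand-ins for node modules; the registry contents are hypothetical and chosen only to exercise the success path and both failure paths:

```python
import importlib

# Hypothetical registry: key -> (module name, attribute name).
_REGISTRY = {
    "JsonDumps": ("json", "dumps"),                         # resolves
    "Missing": ("module_that_does_not_exist_xyz", "Node"),  # ImportError
    "BadAttr": ("json", "no_such_attribute"),               # AttributeError
}

loaded = {}
skipped = {}
for key, (module_name, attr) in _REGISTRY.items():
    try:
        module = importlib.import_module(module_name)
        loaded[key] = getattr(module, attr)
    except (ImportError, AttributeError) as exc:
        # Record the reason instead of failing the whole package import.
        skipped[key] = str(exc)

print("loaded:", sorted(loaded))
print("skipped:", sorted(skipped))
```

Catching only ImportError and AttributeError keeps genuine bugs inside a node module (for example a NameError at import time) visible rather than silently skipped.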
Step 4: Create top-level __init__.py
"""
ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026).
"""
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
Step 5: Commit
git add __init__.py nodes/__init__.py nodes/utils.py requirements.txt
git commit -m "feat: project scaffolding with shared utils and node registration"
Task 2: Extract prismaudio_core — Model Config + Factory
Files:
- Create: prismaudio_core/__init__.py
- Create: prismaudio_core/configs/prismaudio.json (copy from PrismAudio repo)
- Create: prismaudio_core/factory.py (adapted from PrismAudio/models/factory.py)
Step 1: Create prismaudio_core/__init__.py
"""
PrismAudio core inference modules.
Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch).
Only inference-critical code — no training, no JAX/TF dependencies.
"""
Step 2: Copy prismaudio.json config
Fetch from https://raw.githubusercontent.com/FunAudioLLM/ThinkSound/prismaudio/PrismAudio/configs/model_configs/prismaudio.json and save to prismaudio_core/configs/prismaudio.json. This is a JSON config file with no code — copy verbatim.
Step 3: Create factory.py
Extract from PrismAudio/models/factory.py. Keep only these functions (remove training-related code):
- create_model_from_config(model_config): entry point
- create_diffusion_cond_from_config(config): creates the full model
- create_pretransform_from_config(pretransform_config, sample_rate): VAE
- create_autoencoder_from_config(config): AudioAutoencoder
- create_bottleneck_from_config(config): VAEBottleneck
- create_multi_conditioner_from_conditioning_config(config): conditioners
All imports should reference prismaudio_core.models.* instead of PrismAudio.models.*.
Step 4: Commit
git add prismaudio_core/
git commit -m "feat: extract prismaudio_core config and model factory"
Task 3: Extract prismaudio_core — Model Modules
Files:
- Create: prismaudio_core/models/__init__.py
- Create: prismaudio_core/models/dit.py (from PrismAudio/models/dit.py)
- Create: prismaudio_core/models/diffusion.py (from PrismAudio/models/diffusion.py)
- Create: prismaudio_core/models/conditioners.py (from PrismAudio/models/conditioners.py)
- Create: prismaudio_core/models/autoencoders.py (from PrismAudio/models/autoencoders.py)
- Create: prismaudio_core/models/pretransforms.py (from PrismAudio/models/pretransforms.py)
- Create: prismaudio_core/models/blocks.py (from PrismAudio/models/blocks.py)
- Create: prismaudio_core/models/utils.py (from PrismAudio/models/utils.py)
- Create: prismaudio_core/models/bottleneck.py (from PrismAudio/models/bottleneck.py)
- Create: prismaudio_core/models/transformer.py (from PrismAudio/models/transformer.py)
- Create: prismaudio_core/models/local_attention.py (if used by transformer)
Step 1: Extract model files
For each file, extract from the PrismAudio repo. Key modifications:
- Change all internal imports from PrismAudio.models.* to prismaudio_core.models.*
- Remove training-only code (loss functions, training step methods, gradient checkpointing setup)
- Keep all inference paths intact
Critical classes to preserve:
From dit.py:
- DiffusionTransformer: full class with forward(), CFG logic, conditioning assembly
- FourierFeatures: timestep embedding
- Keep empty_clip_feat and empty_sync_feat learned parameters (nn.Parameter, zeros)
From diffusion.py:
- ConditionedDiffusionModelWrapper: with get_conditioning_inputs() and routing logic
- DiTWrapper: thin wrapper that passes all kwargs through
- create_diffusion_cond_from_config(): factory function
From conditioners.py:
- Cond_MLP (type "cond_mlp"): for video_features and text_features. Uses pad_sequence, 2-layer MLP, returns [embeddings, ones_mask]. During eval with batch<16, doubles batch with null embed for CFG
- Sync_MLP (type "sync_mlp"): for sync_features with learnable sync_pos_emb of shape (1,1,8,dim), reshapes into segments of 8, interpolates to target length
- MultiConditioner: iterates over batch_metadata: List[Dict], collects per-sample inputs, calls each conditioner. Returns dict of {key: (tensor, mask)}
- create_multi_conditioner_from_conditioning_config(): factory
From autoencoders.py:
- AudioAutoencoder: with encode_audio() and decode_audio()
- OobleckEncoder, OobleckDecoder: with ResidualUnit, snake activation
- Dependencies: alias_free_torch (SnakeBeta), dac.nn (WNConv1d, WNConvTranspose1d)
From pretransforms.py:
- AutoencoderPretransform: wraps AudioAutoencoder, encode() and decode() methods
From bottleneck.py:
- VAEBottleneck: reparameterization trick (split mean/logvar, sample)
From blocks.py:
- Any shared blocks used by the above (attention blocks, FeedForward, etc.)
From transformer.py:
- ContinuousTransformer: the core transformer with cross-attention, used by DiffusionTransformer
From utils.py:
- load_ckpt_state_dict(): handles .safetensors and .ckpt, optional prefix stripping
- remove_weight_norm_from_model(): used in some inference paths
Step 2: Handle flash-attn gracefully in transformer.py / blocks.py
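Replace hard import flash_attn with:
try:
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
In the attention forward pass, use:
# The two paths expect different layouts: flash_attn_func takes
# [batch, seq, heads, dim] in fp16/bf16 on CUDA, while
# scaled_dot_product_attention takes [batch, heads, seq, dim].
# Transpose q, k, v (and the output) to match whichever path runs.
if HAS_FLASH_ATTN:
    out = flash_attn_func(q, k, v, ...)
else:
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, ...)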
Step 3: Verify imports resolve
Run: python -c "from prismaudio_core.factory import create_model_from_config; print('OK')" from the project root (with ComfyUI's python).
Expected: OK (or import errors to fix iteratively)
Step 4: Commit
git add prismaudio_core/models/
git commit -m "feat: extract prismaudio_core model modules (DiT, conditioners, VAE, diffusion)"
Task 4: Extract prismaudio_core — Inference/Sampling (with callback fix)
Files:
- Create: prismaudio_core/inference/__init__.py
- Create: prismaudio_core/inference/sampling.py (MODIFIED from PrismAudio/inference/sampling.py)
- Create: prismaudio_core/inference/utils.py (from PrismAudio/inference/utils.py)
Step 1: Extract sampling.py WITH callback support added
The original sample_discrete_euler uses tqdm and has no callback parameter.
We MUST copy and modify it to add callback support for ComfyUI progress bars.
import torch
from tqdm import trange

def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Discrete Euler sampler for rectified flow, with optional callback.

    Modified from PrismAudio to add callback parameter for ComfyUI progress reporting.
    Original uses tqdm internally.

    Args:
        model: The diffusion model (DiTWrapper)
        x: Initial noise tensor [B, C, T]
        steps: Number of sampling steps
        sigma_max: Maximum sigma (default 1.0 for rectified flow)
        callback: Optional callable({"i": step, "x": current_x}) for progress
        **extra_args: Passed to model(); includes cross_attn_cond, add_cond,
            sync_cond, cfg_scale, batch_cfg, etc.
    """
    t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)
    for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])):
        dt = t_next - t_curr
        t_curr_tensor = t_curr * torch.ones(x.shape[0], dtype=x.dtype, device=x.device)
        x = x + dt * model(x, t_curr_tensor, **extra_args)
        if callback is not None:
            callback({"i": i, "x": x})
    return x

# Note: sample_rf() and sample() (v-diffusion) are NOT included.
# PrismAudio uses rectified_flow objective which only needs sample_discrete_euler.
# Including unused samplers with potentially wrong math is a maintenance hazard.
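To see the update rule and callback contract in isolation, here is a scalar sketch in plain Python. It assumes the straight-line rectified-flow path x_t = (1 - t) * data + t * noise, whose velocity is the constant noise - data; the names euler_sample, data, and noise are illustrative and not taken from the PrismAudio source:

```python
def euler_sample(velocity, x, steps, sigma_max=1.0, callback=None):
    """Integrate dx/dt = velocity(x, t) from t = sigma_max down to t = 0."""
    ts = [sigma_max * (1 - i / steps) for i in range(steps + 1)]
    for i, (t_curr, t_next) in enumerate(zip(ts[:-1], ts[1:])):
        dt = t_next - t_curr  # negative: time runs from noise (t=1) to data (t=0)
        x = x + dt * velocity(x, t_curr)
        if callback is not None:
            callback({"i": i, "x": x})
    return x


data, noise = 0.25, -1.5
seen_steps = []

# Toy "model": the exact velocity of the straight-line path.
result = euler_sample(
    lambda x, t: noise - data,
    noise,
    steps=8,
    callback=lambda info: seen_steps.append(info["i"]),
)
print(result, seen_steps)
```

Because the velocity is constant along a straight path, Euler integration recovers the data point, and the callback fires once per step, which is what the ComfyUI progress bar relies on.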
Step 2: Extract inference/utils.py
Keep:
- set_audio_channels(audio, target_channels)
- prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device)
Step 3: Verify sampling import
Run: python -c "from prismaudio_core.inference.sampling import sample_discrete_euler; print('OK')"
Expected: OK
Step 4: Commit
git add prismaudio_core/inference/
git commit -m "feat: extract prismaudio_core inference with callback-enabled sampling"
Task 5: PrismAudioModelLoader Node
Files:
- Create: nodes/model_loader.py
Step 1: Write the node
Key design decisions:
- No hf_token widget (security risk — saved to workflow JSON). Uses env var / cached login only.
- Creates model with default config. No duration-dependent sequence-length setup is required at load time.
- The model config's sample_size: 397312 corresponds to ~9s default. For other durations, the Sampler node only changes the latent length of the noise tensor; the DiT adapts to the input shape, so no seq lengths are updated on the model (see the bug fixes above).
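The relationship between sample_size, duration, and latent length is simple arithmetic. A short sketch using the constants defined in nodes/utils.py; the helper name latent_length is hypothetical, though the formula matches the one the Sampler uses:

```python
SAMPLE_RATE = 44100
DOWNSAMPLING_RATIO = 2048
SAMPLE_SIZE = 397312  # sample_size from the model config


def latent_length(duration_s: float) -> int:
    """Number of latent frames for a clip of the given duration."""
    return round(SAMPLE_RATE * duration_s / DOWNSAMPLING_RATIO)


default_seconds = SAMPLE_SIZE / SAMPLE_RATE        # roughly 9.009 s
default_latents = SAMPLE_SIZE // DOWNSAMPLING_RATIO  # 194 latent frames

print(round(default_seconds, 3), default_latents)
print(latent_length(9.0), latent_length(10.0))
```

The default sample_size divides evenly by the downsampling ratio, giving 194 latent frames; a 10 s request gives 215.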
import os
import json
import torch
import folder_paths
import comfy.model_management as mm
import comfy.utils
from .utils import (
    PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
    get_device, get_offload_device, determine_precision, determine_offload_strategy,
    soft_empty_cache, resolve_hf_token,
)

# HuggingFace repo for auto-download
HF_REPO_ID = "FunAudioLLM/PrismAudio"
REQUIRED_FILES = {
    "diffusion": "prismaudio.ckpt",
    "vae": "vae.ckpt",
    "synchformer": "synchformer_state_dict.pth",
}

def _download_if_missing(filename, model_dir, hf_token=None):
    """Download a model file from HuggingFace if not present locally."""
    filepath = os.path.join(model_dir, filename)
    if os.path.exists(filepath):
        return filepath
    from huggingface_hub import hf_hub_download
    print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
    try:
        downloaded = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=model_dir,
            token=hf_token or None,
        )
        return downloaded
    except Exception as e:
        if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower():
            raise RuntimeError(
                f"[PrismAudio] Model '{filename}' requires license acceptance. "
                f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
                f"then set HF_TOKEN env var or run: huggingface-cli login"
            ) from e
        raise

class PrismAudioModelLoader:
    @classmethod
    def INPUT_TYPES(cls):
        register_model_folder()
        return {
            "required": {
                "precision": (["auto", "fp32", "fp16", "bf16"],),
                "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_model(self, precision, offload_strategy):
        device = get_device()
        dtype = determine_precision(precision, device)
        strategy = determine_offload_strategy(offload_strategy)
        token = resolve_hf_token()
        model_dir = get_prismaudio_model_dir()
        # Auto-download missing files
        for key, filename in REQUIRED_FILES.items():
            _download_if_missing(filename, model_dir, hf_token=token)
        # Load config
        config_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "prismaudio_core", "configs", "prismaudio.json"
        )
        with open(config_path) as f:
            model_config = json.load(f)
        # Create model from config
        from prismaudio_core.factory import create_model_from_config
        model = create_model_from_config(model_config)
        # Load diffusion weights
        diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"])
        diffusion_state = comfy.utils.load_torch_file(diffusion_path)
        # Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
        if "state_dict" in diffusion_state:
            diffusion_state = diffusion_state["state_dict"]
        model.load_state_dict(diffusion_state, strict=False)
        # Load VAE weights separately
        # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
        vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
        vae_full_state = comfy.utils.load_torch_file(vae_path)
        # Strip "autoencoder." prefix from keys
        vae_state = {}
        prefix = "autoencoder."
        for k, v in vae_full_state.items():
            if k.startswith(prefix):
                vae_state[k[len(prefix):]] = v
            else:
                vae_state[k] = v
        model.pretransform.load_state_dict(vae_state)
        # Apply precision: DiT + conditioners in user-selected dtype,
        # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
        model.model.to(dtype)  # DiTWrapper
        model.conditioner.to(dtype)  # MultiConditioner
        # model.pretransform stays in fp32
        if strategy == "keep_in_vram":
            model = model.to(device)
        else:
            model = model.to(get_offload_device())
        model.eval()
        return ({
            "model": model,
            "dtype": dtype,
            "strategy": strategy,
            "config": model_config,
            "model_dir": model_dir,
        },)
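The VAE key remapping in load_model can be exercised without any checkpoint. A minimal sketch with plain strings standing in for tensors; the helper name strip_prefix and the sample keys are hypothetical:

```python
def strip_prefix(state: dict, prefix: str) -> dict:
    """Drop `prefix` from keys that carry it; leave all other keys untouched."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }


# Stand-in for a checkpoint whose VAE weights are nested under "autoencoder.".
checkpoint = {
    "autoencoder.encoder.layers.0.weight": "w0",
    "autoencoder.decoder.layers.0.bias": "b0",
    "bottleneck.scale": "s",
}
remapped = strip_prefix(checkpoint, "autoencoder.")
print(sorted(remapped))
```

Keys without the prefix pass through unchanged, matching the else branch in the loader.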
Step 2: Test that ComfyUI discovers the node
Run ComfyUI and check that "PrismAudio Model Loader" appears in the node list.
Step 3: Commit
git add nodes/model_loader.py
git commit -m "feat: PrismAudioModelLoader node with auto-download and adaptive VRAM"
Task 6: PrismAudioFeatureLoader Node
Files:
- Create: nodes/feature_loader.py
Step 1: Write the node
import os
import numpy as np
import torch
from .utils import PRISMAUDIO_CATEGORY

# Keys consumed by the conditioners (video_features, text_features, sync_features)
# global_video_features and global_text_features are NOT consumed by any conditioner
# in the prismaudio.json config; they are unused.
REQUIRED_KEYS = [
    "video_features",
    "text_features",
    "sync_features",
]

class PrismAudioFeatureLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
    RETURN_NAMES = ("features",)
    FUNCTION = "load_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_features(self, npz_path):
        if not os.path.exists(npz_path):
            raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")
        data = np.load(npz_path, allow_pickle=True)
        features = {}
        for key in REQUIRED_KEYS:
            if key in data:
                features[key] = torch.from_numpy(data[key]).float()
            else:
                print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
                # Provide zero tensor rather than None; Cond_MLP/Sync_MLP crash on None
                # Sync_MLP requires length divisible by 8 (segments of 8 frames)
                if key == "sync_features":
                    features[key] = torch.zeros(8, 768)
                else:
                    features[key] = torch.zeros(1, 1024)
        # Load duration if present
        if "duration" in data:
            features["duration"] = float(data["duration"])
        return (features,)
Step 2: Commit
git add nodes/feature_loader.py
git commit -m "feat: PrismAudioFeatureLoader node for pre-computed .npz files"
Task 7: PrismAudioFeatureExtractor Node (Subprocess Bridge)
Files:
- Create: nodes/feature_extractor.py
- Create: scripts/extract_features.py
- Create: scripts/environment.yml
Step 1: Create the conda environment.yml
name: prismaudio-extract
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.10
  - pip
  - ffmpeg<7
  - pip:
    - torch>=2.6.0
    - torchaudio>=2.6.0
    - torchvision>=0.21.0
    - tensorflow-cpu==2.15.0
    - jax
    - jaxlib
    - transformers>=4.52.3
    - decord
    - einops>=0.7.0
    - numpy
    - mediapy
    - git+https://github.com/google-deepmind/videoprism.git
Step 2: Create scripts/extract_features.py
This is a standalone script that:
- Takes --video, --cot_text, --output arguments
- Loads VideoPrism, T5-Gemma, Synchformer
- Extracts features from the video
- Saves as .npz
#!/usr/bin/env python3
"""
Standalone PrismAudio feature extraction script.
Run in a separate conda env with JAX/TF installed.

Usage:
    python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz

Setup:
    conda env create -f environment.yml
    conda activate prismaudio-extract
"""
import argparse
import os
import sys
import numpy as np
import torch

def main():
    parser = argparse.ArgumentParser(description="PrismAudio feature extraction")
    parser.add_argument("--video", required=True, help="Path to input video")
    parser.add_argument("--cot_text", required=True, help="Chain-of-thought description")
    parser.add_argument("--output", required=True, help="Output .npz path")
    parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
    parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
    parser.add_argument("--clip_fps", type=float, default=4.0)
    parser.add_argument("--clip_size", type=int, default=288)
    parser.add_argument("--sync_fps", type=float, default=25.0)
    parser.add_argument("--sync_size", type=int, default=224)
    args = parser.parse_args()
    if not os.path.exists(args.video):
        print(f"Error: Video not found: {args.video}")
        sys.exit(1)
    # Import feature extraction utils (requires JAX/TF)
    from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
    import torchvision.transforms as T
    from decord import VideoReader, cpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Initialize feature extractor
    feat_utils = FeaturesUtils(
        vae_config_path=args.vae_config,
        synchformer_ckpt=args.synchformer_ckpt,
        device=device,
    )
    # Load and preprocess video
    vr = VideoReader(args.video, ctx=cpu(0))
    fps = vr.get_avg_fps()
    total_frames = len(vr)
    duration = total_frames / fps
    # Extract CLIP frames (4fps, 288x288)
    clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
    clip_indices = [min(i, total_frames - 1) for i in clip_indices]
    clip_frames = vr.get_batch(clip_indices).asnumpy()
    clip_transform = T.Compose([
        T.ToPILImage(),
        T.Resize(args.clip_size),
        T.CenterCrop(args.clip_size),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device)
    # Extract Sync frames (25fps, 224x224)
    sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
    sync_indices = [min(i, total_frames - 1) for i in sync_indices]
    sync_frames = vr.get_batch(sync_indices).asnumpy()
    sync_transform = T.Compose([
        T.ToPILImage(),
        T.Resize(args.sync_size),
        T.CenterCrop(args.sync_size),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device)
    # Extract features
    print("[PrismAudio] Encoding text with T5-Gemma...")
    text_features = feat_utils.encode_t5_text([args.cot_text])
    print("[PrismAudio] Encoding video with VideoPrism...")
    global_video_features, video_features, global_text_features = \
        feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text])
    print("[PrismAudio] Encoding video with Synchformer...")
    sync_features = feat_utils.encode_video_with_sync(sync_input)
    # Save as .npz
    np.savez(
        args.output,
        video_features=video_features.cpu().numpy(),
        global_video_features=global_video_features.cpu().numpy(),
        text_features=text_features.cpu().numpy(),
        global_text_features=global_text_features.cpu().numpy(),
        sync_features=sync_features.cpu().numpy(),
        caption_cot=args.cot_text,
        duration=duration,
    )
    print(f"[PrismAudio] Features saved to {args.output}")

if __name__ == "__main__":
    main()
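The frame-index resampling used for both the CLIP and Sync streams can be checked on its own. A small sketch of that logic in plain Python; the helper name sample_frame_indices is hypothetical, and the formula is the one from the script above:

```python
def sample_frame_indices(total_frames: int, src_fps: float, target_fps: float) -> list[int]:
    """Pick source-frame indices that approximate `target_fps`, clamped to range."""
    duration = total_frames / src_fps
    count = int(duration * target_fps)
    return [min(int(i * src_fps / target_fps), total_frames - 1) for i in range(count)]


# A 2-second clip at 30 fps (60 frames), resampled for both streams.
clip_indices = sample_frame_indices(60, 30.0, 4.0)    # 4 fps stream
sync_indices = sample_frame_indices(60, 30.0, 25.0)   # 25 fps stream
print(clip_indices)
print(len(sync_indices), sync_indices[-1])
```

When the target rate exceeds the source rate, indices repeat rather than run past the end, because of the clamp to total_frames - 1.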
Step 3: Create the feature extractor node
import os
import hashlib
import subprocess
import tempfile
import torch
from .utils import PRISMAUDIO_CATEGORY
from .feature_loader import PrismAudioFeatureLoader

def _hash_inputs(video_tensor, cot_text):
    """Create a hash of the inputs for caching."""
    h = hashlib.sha256()
    h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024])  # First 1MB for speed
    h.update(cot_text.encode())
    return h.hexdigest()[:16]

def _save_video_tensor_to_mp4(video_tensor, output_path, fps=30):
    """Save ComfyUI IMAGE tensor [T,H,W,C] to MP4."""
    import torchvision.io as tvio
    # ComfyUI IMAGE is [T,H,W,C] float32 [0,1]
    frames = (video_tensor * 255).to(torch.uint8)
    # torchvision write_video expects [T,H,W,C] uint8
    tvio.write_video(output_path, frames, fps=fps)

class PrismAudioFeatureExtractor:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": ("IMAGE",),
                "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
            },
            "optional": {
                "python_env": ("STRING", {"default": "python", "tooltip": "Path to python binary with JAX/TF (e.g., /path/to/conda/envs/prismaudio-extract/bin/python)"}),
                "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
                "synchformer_ckpt": ("STRING", {"default": "", "tooltip": "Path to synchformer checkpoint (auto-resolved if empty)"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
    RETURN_NAMES = ("features",)
    FUNCTION = "extract_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def extract_features(self, video, caption_cot, python_env="python", cache_dir="", synchformer_ckpt=""):
        # Determine cache directory
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features")
        os.makedirs(cache_dir, exist_ok=True)
        # Check cache
        cache_hash = _hash_inputs(video, caption_cot)
        cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
        if os.path.exists(cached_path):
            print(f"[PrismAudio] Using cached features: {cached_path}")
            loader = PrismAudioFeatureLoader()
            return loader.load_features(cached_path)
        # Save video to temp file
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_video = tmp.name
        _save_video_tensor_to_mp4(video, tmp_video)
        # Build subprocess command
        script_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "scripts", "extract_features.py"
        )
        cmd = [
            python_env,
            script_path,
            "--video", tmp_video,
            "--cot_text", caption_cot,
            "--output", cached_path,
        ]
        if synchformer_ckpt:
            cmd.extend(["--synchformer_ckpt", synchformer_ckpt])
        print(f"[PrismAudio] Extracting features via subprocess...")
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=600,  # 10 minute timeout
            )
            if result.returncode != 0:
                raise RuntimeError(
                    f"[PrismAudio] Feature extraction failed:\n{result.stderr}"
                )
            print(result.stdout)
        finally:
            if os.path.exists(tmp_video):
                os.unlink(tmp_video)
        # Load the extracted features
        loader = PrismAudioFeatureLoader()
        return loader.load_features(cached_path)
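The cache-then-subprocess flow can be tried end to end with only the standard library. A reduced sketch, assuming a trivial worker that writes a text file in place of the real extraction script; cache_key, run_worker, and WORKER are hypothetical stand-ins:

```python
import hashlib
import os
import subprocess
import sys
import tempfile


def cache_key(payload: bytes, text: str) -> str:
    """Hash the first 1 MiB of the payload plus the caption into a short key."""
    h = hashlib.sha256()
    h.update(payload[: 1024 * 1024])
    h.update(text.encode())
    return h.hexdigest()[:16]


def run_worker(python_bin: str, code: str, out_path: str, timeout: float = 60.0) -> None:
    """Run a worker in a separate interpreter and surface stderr on failure."""
    result = subprocess.run(
        [python_bin, "-c", code, out_path],
        capture_output=True, text=True, timeout=timeout,
    )
    if result.returncode != 0:
        raise RuntimeError(f"worker failed:\n{result.stderr}")


# Stand-in worker: writes a marker file to the path given as its first argument.
WORKER = "import sys; open(sys.argv[1], 'w').write('features')"

cache_dir = tempfile.mkdtemp()
out_path = os.path.join(cache_dir, cache_key(b"frames", "a dog barks") + ".txt")
if not os.path.exists(out_path):  # cache miss: invoke the isolated interpreter
    run_worker(sys.executable, WORKER, out_path)
with open(out_path) as f:
    cached = f.read()
print(cached)
```

In the real node the interpreter path comes from the python_env input, so the heavy JAX/TF dependencies never load into the ComfyUI process.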
Step 4: Commit
git add nodes/feature_extractor.py scripts/extract_features.py scripts/environment.yml
git commit -m "feat: PrismAudioFeatureExtractor node with subprocess bridge and conda env"
Task 8: PrismAudioSampler Node
Files:
- Create: nodes/sampler.py
Step 1: Write the sampler node
This is the core node. Key fixes from review:
- Metadata is a TUPLE of dicts, not a flat dict
- video_exist is torch.tensor, not Python bool
- Empty features are zero tensors, not None
- Peak normalization before clamp
- Sequence lengths are NOT set on the DiT config; the model adapts to input tensor shapes, so only the noise latent length depends on duration
- No callback kwarg forwarded to model — callback is handled by our modified sample_discrete_euler
import torch
import comfy.model_management as mm
import comfy.utils
from .utils import (
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
get_device, get_offload_device, soft_empty_cache,
)
class PrismAudioSampler:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("PRISMAUDIO_MODEL",),
"features": ("PRISMAUDIO_FEATURES",),
"duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds"}),
"steps": ("INT", {"default": 24, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
"cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
},
}
    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, features, duration, steps, cfg_scale, seed):
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]

        # Compute latent dimensions
        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
        # Note: no seq length config is needed; the model adapts to input tensor
        # shapes dynamically via its transformer architecture.

        # Determine if video features are present (not all zeros).
        # bool() turns the 0-dim comparison result into a Python bool, so that
        # torch.tensor(has_video) below is built from a plain bool, not a tensor.
        has_video = features.get("video_features") is not None and bool(features["video_features"].abs().sum() > 0)

        # Build metadata as a TUPLE of dicts (one per batch sample)
        # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
        sample_meta = {
            "video_features": features["video_features"].to(device, dtype=dtype),
            "text_features": features["text_features"].to(device, dtype=dtype),
            "sync_features": features["sync_features"].to(device, dtype=dtype),
            "video_exist": torch.tensor(has_video),
        }
        metadata = (sample_meta,)

        # Move model to device if offloaded
        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            # Run conditioning
            conditioning = diffusion.conditioner(metadata, device)

            # Handle missing video: substitute the learned empty sync embedding
            if not has_video:
                _substitute_empty_features(diffusion, conditioning, device, dtype)

            # Assemble conditioning inputs for the DiT
            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (on MPS, draw with a CPU generator and
            # move the result to the device)
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            # Sample with progress bar
            pbar = comfy.utils.ProgressBar(steps)
            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

        # Offload diffusion model and conditioner before VAE decode
        if strategy == "offload_to_cpu":
            diffusion.model.to(get_offload_device())
            diffusion.conditioner.to(get_offload_device())
            soft_empty_cache()
            diffusion.pretransform.to(device)

        # VAE decode in fp32 (snake activations overflow in fp16).
        # no_grad is repeated here so decoding never tracks gradients.
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, enabled=False):
            audio = diffusion.pretransform.decode(fakes.float())

        # Offload VAE
        if strategy == "offload_to_cpu":
            diffusion.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak normalize then clamp (matching reference: div by max abs before clamp)
        audio = audio.float()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)

        # Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int}
        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)


def _substitute_empty_features(diffusion, conditioning, device, dtype):
    """Replace sync conditioning with learned empty embedding when video is absent.

    Only substitutes sync_features, NOT video_features. The reference code
    (predict.py/app.py) checks for 'metaclip_features' which doesn't exist in the
    prismaudio.json config, so video substitution never runs. Cond_MLP with zero
    input + bias-free linear layers naturally produces near-zero output.

    The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim].
    """
    dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model

    # Only substitute sync_features (matching reference behavior for prismaudio config)
    if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning:
        empty = dit.empty_sync_feat.to(device, dtype=dtype)
        cond_tensor = conditioning['sync_features'][0]
        batch_size = cond_tensor.shape[0]
        empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1)
        conditioning['sync_features'][0] = empty_expanded
        conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device)
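
The peak normalization at the end of generate can be illustrated without PyTorch. The sketch below is a dependency-free stand-in that operates on a plain list of samples. The helper name peak_normalize and its eps parameter are illustrative and not part of the node code. It shows why the 1e-8 floor matters: an all-zero (silent) output would otherwise divide by zero.

```python
def peak_normalize(samples, eps=1e-8):
    """Scale samples so the largest magnitude is 1.0, then clamp to [-1, 1].

    Mirrors (audio / audio.abs().max().clamp(min=eps)).clamp(-1, 1) on a
    plain Python list (illustrative helper, not part of the node code).
    """
    # Floor the peak at eps so silent input does not divide by zero.
    peak = max(max((abs(s) for s in samples), default=0.0), eps)
    return [max(-1.0, min(1.0, s / peak)) for s in samples]


print(peak_normalize([0.5, -2.0, 1.0]))  # the loudest sample lands on -1.0
print(peak_normalize([0.0, 0.0]))        # silence stays silent
```

Note that this normalizes every clip to full scale, so relative loudness between separately generated clips is not preserved.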
Step 2: Verify the node registers
Start ComfyUI and check that "PrismAudio Sampler" appears in the add-node menu.
Step 3: Commit
git add nodes/sampler.py
git commit -m "feat: PrismAudioSampler node with correct metadata format and peak normalization"
Task 9: PrismAudioTextOnly Node
Files:
- Create: nodes/text_only.py
Step 1: Write the text-only node
Key fixes from review:
- Uses AutoModelForSeq2SeqLM.get_encoder(), not AutoModel.encoder
- No truncation (matching reference)
- Metadata is tuple of dicts with torch.tensor(False) for video_exist
- Zero tensors for video/sync features, not None
- Peak normalization
import torch
import comfy.model_management as mm
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
    get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
)
from .sampler import _substitute_empty_features


class PrismAudioTextOnly:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Text description for audio generation"}),
                "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
                "steps": ("INT", {"default": 24, "min": 1, "max": 100}),
                "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]
        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Encode text with T5-Gemma
        text_features = _encode_text_t5(text_prompt, device, dtype)

        # Build metadata: tuple of one dict per sample
        # Use zero tensors for video/sync (not None: Cond_MLP crashes on None via pad_sequence)
        # Sync_MLP requires length divisible by 8 (segments of 8 frames), so the minimum is [8, 768]
        # The sync placeholder is replaced by the learned empty embedding after
        # conditioning; the video placeholder stays as zeros.
        sample_meta = {
            "video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
            "text_features": text_features.to(device, dtype=dtype),
            "sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
            "video_exist": torch.tensor(False),
        }
        metadata = (sample_meta,)

        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            conditioning = diffusion.conditioner(metadata, device)

            # Substitute the learned empty sync embedding
            # (video is left as-is; see _substitute_empty_features)
            _substitute_empty_features(diffusion, conditioning, device, dtype)

            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (on MPS, draw with a CPU generator and
            # move the result to the device)
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            pbar = comfy.utils.ProgressBar(steps)
            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

        if strategy == "offload_to_cpu":
            diffusion.model.to(get_offload_device())
            diffusion.conditioner.to(get_offload_device())
            soft_empty_cache()
            diffusion.pretransform.to(device)

        # VAE decode in fp32 (snake activations overflow in fp16).
        # no_grad is repeated here so decoding never tracks gradients.
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, enabled=False):
            audio = diffusion.pretransform.decode(fakes.float())

        if strategy == "offload_to_cpu":
            diffusion.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak normalize then clamp
        audio = audio.float()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)

        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)


# T5-Gemma encoder singleton
_t5_model = None
_t5_tokenizer = None


def _encode_text_t5(text, device, dtype):
    """Encode text using T5-Gemma.

    Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference
    FeaturesUtils.encode_t5_text() implementation.
    No truncation applied (matching reference behavior).
    """
    global _t5_model, _t5_tokenizer
    if _t5_model is None:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        model_id = "google/t5gemma-l-l-ul2-it"
        token = resolve_hf_token()
        print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
        _t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        _t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder()
        _t5_model.eval()

    # Move to the target device on every call: the encoder is parked on the CPU
    # after each encode (see below), so a cached instance must be moved back.
    _t5_model.to(device, dtype=dtype)

    tokens = _t5_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        outputs = _t5_model(**tokens)

    # Move T5 off GPU after encoding to save VRAM
    _t5_model.to("cpu")
    soft_empty_cache()

    return outputs.last_hidden_state.squeeze(0)  # [seq_len, dim]
Step 2: Commit
git add nodes/text_only.py
git commit -m "feat: PrismAudioTextOnly node with correct T5-Gemma encoding"
Task 10: Integration Testing & Polish
Files:
- Modify: nodes/__init__.py (verify all imports work)
- Modify: __init__.py (verify top-level registration)
Step 1: Verify all node imports resolve
Run from ComfyUI's Python:
cd /path/to/ComfyUI
python -c "
import sys
sys.path.insert(0, 'custom_nodes/ComfyUI-PrismAudio')
from nodes import NODE_CLASS_MAPPINGS
print('Registered nodes:', list(NODE_CLASS_MAPPINGS.keys()))
"
Expected output:
Registered nodes: ['PrismAudioModelLoader', 'PrismAudioFeatureLoader', 'PrismAudioFeatureExtractor', 'PrismAudioSampler', 'PrismAudioTextOnly']
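
Comparing the printed list by eye is easy to get wrong, so the same check can be made explicit. The following is a small sketch, assuming NODE_CLASS_MAPPINGS is a dict keyed by node name as in the expected output above. EXPECTED_NODES and check_registration are hypothetical helpers, and the example dict stands in for the real import.

```python
EXPECTED_NODES = {
    "PrismAudioModelLoader",
    "PrismAudioFeatureLoader",
    "PrismAudioFeatureExtractor",
    "PrismAudioSampler",
    "PrismAudioTextOnly",
}


def check_registration(mappings):
    """Return (missing, unexpected) node names relative to EXPECTED_NODES."""
    registered = set(mappings)
    missing = sorted(EXPECTED_NODES - registered)
    unexpected = sorted(registered - EXPECTED_NODES)
    return missing, unexpected


# Stand-in for `from nodes import NODE_CLASS_MAPPINGS`, with one node missing.
example = {name: object for name in EXPECTED_NODES if name != "PrismAudioTextOnly"}
print(check_registration(example))  # (['PrismAudioTextOnly'], [])
```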
Step 2: Fix any import errors iteratively
Common issues:
- prismaudio_core internal imports may reference wrong module paths
- Missing model submodules in prismaudio_core/models/
- flash-attn fallback not properly guarded
Step 3: Test model loading (requires GPU + model files)
python -c "
from prismaudio_core.factory import create_model_from_config
import json
with open('prismaudio_core/configs/prismaudio.json') as f:
    config = json.load(f)
model = create_model_from_config(config)
print('Model created, params:', sum(p.numel() for p in model.parameters()) / 1e6, 'M')
"
Expected: Model created, params: ~518 M
Step 4: End-to-end test with pre-computed features
If you have a .npz feature file from the PrismAudio repo's demo data, test the full pipeline in ComfyUI:
- PrismAudioModelLoader -> PrismAudioFeatureLoader -> PrismAudioSampler -> Preview Audio node
Step 5: Verify variable duration handling
Test with multiple durations (5s, 10s, 20s) to ensure the model adapts to different input shapes and produces audio of the expected length.
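
To know what "expected length" means for each test duration, the latent length formula from the sampler can be evaluated ahead of time. The sketch below uses ASSUMED values for SAMPLE_RATE and DOWNSAMPLING_RATIO (44100 and 2048), so substitute the real constants from nodes/utils.py. It also assumes the VAE decoder upsamples by exactly the downsampling ratio. expected_lengths is a hypothetical helper.

```python
SAMPLE_RATE = 44100        # ASSUMED value; use the constant from nodes/utils.py
DOWNSAMPLING_RATIO = 2048  # ASSUMED value; use the constant from nodes/utils.py


def expected_lengths(duration):
    """Return (latent_length, decoded_samples) for a duration in seconds."""
    # Same formula as the sampler nodes.
    latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
    # Assumes the decoder emits DOWNSAMPLING_RATIO samples per latent frame.
    return latent_length, latent_length * DOWNSAMPLING_RATIO


for seconds in (5.0, 10.0, 20.0):
    latent, samples = expected_lengths(seconds)
    print(f"{seconds}s -> {latent} latent frames, {samples} samples "
          f"({samples / SAMPLE_RATE:.3f}s of audio)")
```

Because the latent length is rounded to a whole number of frames, the decoded audio will generally be slightly longer or shorter than the requested duration, so compare against the computed sample count rather than duration times sample rate.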
Step 6: Commit
git add -A
git commit -m "feat: integration fixes and verification"
Task 11: README
Files:
- Create: README.md
Step 1: Write README covering:
- What PrismAudio is (brief, link to paper)
- Installation (clone, pip install requirements, optional extraction env setup)
- Node descriptions with input/output tables
- Example workflows (quality path with FeatureExtractor, quick path with FeatureLoader, text-only)
- HuggingFace authentication (2 methods: HF_TOKEN env var, huggingface-cli login)
  - Note: hf_token is NOT a node widget for security reasons
- Which models may be gated (T5-Gemma, potentially Stable Audio VAE)
- Model file sizes: diffusion ~2.7GB, VAE ~2.5GB, synchformer ~950MB
- Extraction env setup via conda environment.yml
- Troubleshooting (VRAM, flash-attn optional, gated models)
- Credits and license
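
For the authentication section of the README, it may help to show how the two methods interact. The sketch below is one possible shape for the token lookup, not the actual resolve_hf_token from nodes/utils.py. resolve_token_sketch is a hypothetical name, and returning None is assumed to let the HuggingFace libraries fall back to the credential saved by huggingface-cli login.

```python
import os


def resolve_token_sketch(env=None):
    """Return HF_TOKEN from the environment, or None if it is unset or blank."""
    env = os.environ if env is None else env
    token = (env.get("HF_TOKEN") or "").strip()
    # None means "no explicit token"; the cached CLI login is then expected to apply.
    return token or None


print(resolve_token_sketch({"HF_TOKEN": " hf_example "}))  # hf_example
print(resolve_token_sketch({}))                            # None
```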
Step 2: Commit
git add README.md
git commit -m "docs: README with installation and usage instructions"
Dependency Graph
Task 1 (scaffolding)
├── Task 2 (core config + factory) ──┐
│ └── Task 3 (core models) ──────┤
│ └── Task 4 (core sampling)┤
│ ├── Task 5 (ModelLoader node)
│ ├── Task 6 (FeatureLoader node)
│ ├── Task 7 (FeatureExtractor node)
│ ├── Task 8 (Sampler node)
│ └── Task 9 (TextOnly node)
└────────────────────────────────────────── Task 10 (Integration)
└── Task 11 (README)
Tasks 5-9 can be parallelized after Task 4 is complete. Task 3 is the heaviest — it involves extracting and adapting ~10 model files.
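
The parallelization claim can be checked mechanically by encoding the graph and grouping tasks into batches whose dependencies are all satisfied. The sketch below uses the standard-library graphlib module (Python 3.9+). The edge list is one reading of the diagram: in particular, it assumes Task 10 waits on all of Tasks 5-9. parallel_batches is a hypothetical helper.

```python
from graphlib import TopologicalSorter

# task -> set of tasks it depends on (one reading of the diagram above)
DEPENDENCIES = {
    2: {1},
    3: {2},
    4: {3},
    5: {4}, 6: {4}, 7: {4}, 8: {4}, 9: {4},
    10: {5, 6, 7, 8, 9},  # ASSUMPTION: integration waits on every node task
    11: {10},
}


def parallel_batches(dependencies):
    """Group tasks into batches; tasks within a batch can run concurrently."""
    sorter = TopologicalSorter(dependencies)
    sorter.prepare()
    batches = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())  # all tasks whose dependencies are done
        batches.append(ready)
        sorter.done(*ready)
    return batches


for batch in parallel_batches(DEPENDENCIES):
    print(batch)
```

Under these assumed edges, the only batch with more than one task is the one containing Tasks 5-9. Everything else is strictly sequential.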