# ComfyUI-PrismAudio Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Build ComfyUI custom nodes for PrismAudio video-to-audio and text-to-audio generation with adaptive VRAM management and isolated feature extraction.

**Architecture:** Selective code extraction from the PrismAudio `prismaudio` branch into a `prismaudio_core/` module. Five ComfyUI nodes (ModelLoader, FeatureLoader, FeatureExtractor, Sampler, TextOnly). Feature extraction runs through a subprocess bridge to an isolated JAX/TF environment. Models auto-download from HuggingFace, with gated-model support.

**Tech Stack:** PyTorch, ComfyUI APIs (folder_paths, comfy.model_management, comfy.utils), HuggingFace Hub, transformers (T5-Gemma), einops, k-diffusion, safetensors

---

## Bug Fixes Applied (from review)

This plan incorporates fixes for all 14 bugs identified during review:

1. **sample_discrete_euler callback**: Copy function into prismaudio_core, add callback param to the sampling loop
2. **Metadata format**: Return `(dict,)` tuple, not flat dict — matches `MultiConditioner.forward(batch_metadata: List[Dict])`
3. **video_exist**: Use `torch.tensor(True/False)`, not Python bool
4. **None features**: Use zero tensors of correct shape, never None — `pad_sequence(None)` crashes
5. **update_seq_lengths removed**: Does not exist in source. Model adapts to input shapes dynamically — no seq length config needed
6. **Sequence length config**: Not needed — model handles variable lengths natively via input tensor shapes
7. **T5-Gemma class**: Use `AutoModelForSeq2SeqLM.get_encoder()`, not `AutoModel.encoder`
8. **Peak normalization**: Add `audio / audio.abs().max().clamp(min=1e-8)` before clamp
9. **Empty feature substitution**: Match reference approach — substitute on raw conditioning output with correct shapes
10. **hf_token security**: Remove STRING widget entirely. Rely on env var / `huggingface-cli login` only. Document in README
11. **Synchformer size**: Corrected to ~950MB in docs
12. **T5 truncation**: Match reference — `truncation=False`, no max_length
13. **Remove global_video/text_features from metadata**: Not consumed by any conditioner
14. **Add tqdm to requirements**

### Bug Fixes Applied (from second review)

15. **Sync_MLP zero-tensor crash**: Sync zero-tensor fallback must be `[8, 768]` not `[1, 768]` — Sync_MLP does `length // 8` which gives 0 for length=1, causing `F.interpolate` on empty tensor
16. **sample_discrete_euler undefined `i`**: Loop needs `enumerate()` — `for i, (t_curr, t_prev) in enumerate(zip(...))`
17. **_update_seq_lengths removed entirely**: Was a no-op (attributes don't exist on DiT). Model handles variable lengths natively — function deleted
18. **cot_description removed from Sampler**: Was dead code — features already contain pre-computed text_features
19. **Conditioner VRAM leak**: Add `diffusion.conditioner.to(get_offload_device())` after generation in offload path
20. **VAE size corrected**: ~2.52GB, not ~300MB

### Bug Fixes Applied (from third review)

21. **Remove video_features substitution**: `_substitute_empty_features` should only substitute sync_features. Reference code checks for `metaclip_features` (wrong key for prismaudio config), so video substitution never runs. Cond_MLP with zero input + bias-free linears naturally produces near-zero output
22. **Remove dead `sample()` and `sample_rf`**: Wrong noise schedule (linear vs cosine), never called for rectified_flow. Only keep `sample_discrete_euler`
23. **VAE decode in fp32**: Keep pretransform in fp32 even when rest of model is fp16/bf16 — snake activations overflow in fp16
24. **Lazy imports in nodes/__init__.py**: Use try/except to allow incremental development
25. **MPS Generator guard**: `torch.Generator(device="cpu")` on Apple Silicon, move noise to device after
26. **Use comfy.utils.load_torch_file for VAE**: Consistent with diffusion loading, handles PyTorch 2.6+ weights_only default
27. **Task 10 stale reference**: Remove mention of `_update_seq_lengths`

### Bug Fixes Applied (from fourth review)

28. **TextOnly missing MPS guard**: Fix-on-fix regression — MPS Generator guard was applied to Sampler but not TextOnly
29. **TextOnly noise dtype**: Was passing dtype to torch.randn directly (fp16 noise), now generates fp32 then converts (matching Sampler)
30. **Sync substitution seq length**: Low-severity divergence from reference, accepted (DiT handles variable-length sync_cond)

---

### Task 1: Project Scaffolding

**Files:**
- Create: `__init__.py`
- Create: `nodes/__init__.py`
- Create: `nodes/utils.py`
- Create: `requirements.txt`

**Step 1: Create requirements.txt**

```
einops>=0.7.0
safetensors
huggingface_hub
transformers>=4.52.3
k-diffusion>=0.1.1
alias-free-torch
descript-audio-codec
tqdm
```

**Step 2: Create nodes/utils.py with shared helpers**

```python
import os
import torch
import folder_paths
import comfy.model_management as mm

PRISMAUDIO_CATEGORY = "PrismAudio"
SAMPLE_RATE = 44100
DOWNSAMPLING_RATIO = 2048
IO_CHANNELS = 64


def get_prismaudio_model_dir():
    """Get or create the prismaudio model directory."""
    model_dir = os.path.join(folder_paths.models_dir, "prismaudio")
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def register_model_folder():
    """Register prismaudio model folder with ComfyUI."""
    model_dir = get_prismaudio_model_dir()
    folder_paths.add_model_folder_path("prismaudio", model_dir)


def get_device():
    return mm.get_torch_device()


def get_offload_device():
    return mm.unet_offload_device()


def get_free_memory(device=None):
    if device is None:
        device = get_device()
    return mm.get_free_memory(device)


def soft_empty_cache():
    mm.soft_empty_cache()


def determine_precision(preference, device):
    """Determine the best precision for the given device."""
    if preference != "auto":
        return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference]
    if device.type == "cpu":
        return torch.float32
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16


def determine_offload_strategy(preference):
    """Determine offload strategy based on available VRAM."""
    if preference != "auto":
        return preference
    free_mem = get_free_memory()
    gb = free_mem / (1024 ** 3)
    if gb >= 24:
        return "keep_in_vram"
    else:
        return "offload_to_cpu"


def try_import_flash_attn():
    """Try to import flash attention, return None if unavailable."""
    try:
        import flash_attn
        return flash_attn
    except ImportError:
        return None


def resolve_hf_token():
    """Resolve HF token from env var or cached login. No widget — security risk."""
    env_token = os.environ.get("HF_TOKEN")
    if env_token:
        return env_token
    # huggingface_hub will use cached token automatically if None is passed
    return None
```

**Step 3: Create nodes/__init__.py**

```python
import importlib

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

# Lazy imports — allows incremental development (nodes can be added one at a time)
_NODES = {
    "PrismAudioModelLoader": (".model_loader", "PrismAudioModelLoader", "PrismAudio Model Loader"),
    "PrismAudioFeatureLoader": (".feature_loader", "PrismAudioFeatureLoader", "PrismAudio Feature Loader"),
    "PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
    "PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
    "PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
}

for key, (module_path, class_name, display_name) in _NODES.items():
    try:
        mod = importlib.import_module(module_path, package=__name__)
        NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
        NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
    except (ImportError, AttributeError) as e:
        print(f"[PrismAudio] Skipping {key}: {e}")
```

**Step 4: Create top-level __init__.py**

```python
"""
ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026).
"""
import os
import sys

# ComfyUI does not put custom-node directories on sys.path; prismaudio_core uses absolute imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
```

**Step 5: Commit**

```bash
git add __init__.py nodes/__init__.py nodes/utils.py requirements.txt
git commit -m "feat: project scaffolding with shared utils and node registration"
```

---

### Task 2: Extract prismaudio_core — Model Config + Factory

**Files:**
- Create: `prismaudio_core/__init__.py`
- Create: `prismaudio_core/configs/prismaudio.json` (copy from PrismAudio repo)
- Create: `prismaudio_core/factory.py` (adapted from `PrismAudio/models/factory.py`)

**Step 1: Create prismaudio_core/__init__.py**

```python
"""
PrismAudio core inference modules.

Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch).
Only inference-critical code — no training, no JAX/TF dependencies.
"""
```

**Step 2: Copy prismaudio.json config**

Fetch from `https://raw.githubusercontent.com/FunAudioLLM/ThinkSound/prismaudio/PrismAudio/configs/model_configs/prismaudio.json` and save to `prismaudio_core/configs/prismaudio.json`. This is a JSON config file with no code — copy verbatim.

**Step 3: Create factory.py**

Extract from `PrismAudio/models/factory.py`. Keep only these functions (remove training-related code):
- `create_model_from_config(model_config)` — entry point
- `create_diffusion_cond_from_config(config)` — creates the full model
- `create_pretransform_from_config(pretransform_config, sample_rate)` — VAE
- `create_autoencoder_from_config(config)` — AudioAutoencoder
- `create_bottleneck_from_config(config)` — VAEBottleneck
- `create_multi_conditioner_from_conditioning_config(config)` — conditioners

All imports should reference `prismaudio_core.models.*` instead of `PrismAudio.models.*`.
**Step 4: Commit**

```bash
git add prismaudio_core/
git commit -m "feat: extract prismaudio_core config and model factory"
```

---

### Task 3: Extract prismaudio_core — Model Modules

**Files:**
- Create: `prismaudio_core/models/__init__.py`
- Create: `prismaudio_core/models/dit.py` (from `PrismAudio/models/dit.py`)
- Create: `prismaudio_core/models/diffusion.py` (from `PrismAudio/models/diffusion.py`)
- Create: `prismaudio_core/models/conditioners.py` (from `PrismAudio/models/conditioners.py`)
- Create: `prismaudio_core/models/autoencoders.py` (from `PrismAudio/models/autoencoders.py`)
- Create: `prismaudio_core/models/pretransforms.py` (from `PrismAudio/models/pretransforms.py`)
- Create: `prismaudio_core/models/blocks.py` (from `PrismAudio/models/blocks.py`)
- Create: `prismaudio_core/models/utils.py` (from `PrismAudio/models/utils.py`)
- Create: `prismaudio_core/models/bottleneck.py` (from `PrismAudio/models/bottleneck.py`)
- Create: `prismaudio_core/models/transformer.py` (from `PrismAudio/models/transformer.py`)
- Create: `prismaudio_core/models/local_attention.py` (if used by transformer)

**Step 1: Extract model files**

For each file, extract from the PrismAudio repo.
Key modifications:
- Change all internal imports from `PrismAudio.models.*` to `prismaudio_core.models.*`
- Remove training-only code (loss functions, training step methods, gradient checkpointing setup)
- Keep all inference paths intact

**Critical classes to preserve:**

From `dit.py`:
- `DiffusionTransformer` — full class with `forward()`, CFG logic, conditioning assembly
- `FourierFeatures` — timestep embedding
- Keep `empty_clip_feat` and `empty_sync_feat` learned parameters (nn.Parameter, zeros)

From `diffusion.py`:
- `ConditionedDiffusionModelWrapper` — with `get_conditioning_inputs()` and routing logic
- `DiTWrapper` — thin wrapper that passes all kwargs through
- `create_diffusion_cond_from_config()` — factory function

From `conditioners.py`:
- `Cond_MLP` (type `"cond_mlp"`) — for video_features and text_features. Uses `pad_sequence`, 2-layer MLP, returns `[embeddings, ones_mask]`. During eval with batch<16, doubles batch with null embed for CFG
- `Sync_MLP` (type `"sync_mlp"`) — for sync_features with learnable `sync_pos_emb` of shape (1,1,8,dim), reshapes into segments of 8, interpolates to target length
- `MultiConditioner` — iterates over `batch_metadata: List[Dict]`, collects per-sample inputs, calls each conditioner. Returns a dict of `{key: [tensor, mask]}` (a mutable list, which `_substitute_empty_features` relies on)
- `create_multi_conditioner_from_conditioning_config()` — factory

From `autoencoders.py`:
- `AudioAutoencoder` — with `encode_audio()` and `decode_audio()`
- `OobleckEncoder`, `OobleckDecoder` — with ResidualUnit, snake activation
- Dependencies: `alias_free_torch` (SnakeBeta), `dac.nn` (WNConv1d, WNConvTranspose1d)

From `pretransforms.py`:
- `AutoencoderPretransform` — wraps AudioAutoencoder, `encode()` and `decode()` methods

From `bottleneck.py`:
- `VAEBottleneck` — reparameterization trick (split mean/logvar, sample)

From `blocks.py`:
- Any shared blocks used by the above (attention blocks, FeedForward, etc.)
From `transformer.py`:
- `ContinuousTransformer` — the core transformer with cross-attention, used by DiffusionTransformer

From `utils.py`:
- `load_ckpt_state_dict()` — handles .safetensors and .ckpt, optional prefix stripping
- `remove_weight_norm_from_model()` — used in some inference paths

**Step 2: Handle flash-attn gracefully in transformer.py / blocks.py**

Replace hard `import flash_attn` with:

```python
try:
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
```

In the attention forward pass, use:

```python
# Assuming q/k/v are [batch, heads, seq, head_dim] (the SDPA layout).
# flash_attn_func wants [batch, seq, heads, head_dim] and only fp16/bf16 CUDA tensors.
if HAS_FLASH_ATTN and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
    out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), ...).transpose(1, 2)
else:
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, ...)
```

**Step 3: Verify imports resolve**

Run: `python -c "from prismaudio_core.factory import create_model_from_config; print('OK')"` from the project root (with ComfyUI's python).
Expected: `OK` (or import errors to fix iteratively)

**Step 4: Commit**

```bash
git add prismaudio_core/models/
git commit -m "feat: extract prismaudio_core model modules (DiT, conditioners, VAE, diffusion)"
```

---

### Task 4: Extract prismaudio_core — Inference/Sampling (with callback fix)

**Files:**
- Create: `prismaudio_core/inference/__init__.py`
- Create: `prismaudio_core/inference/sampling.py` (MODIFIED from `PrismAudio/inference/sampling.py`)
- Create: `prismaudio_core/inference/utils.py` (from `PrismAudio/inference/utils.py`)

**Step 1: Extract sampling.py WITH callback support added**

The original `sample_discrete_euler` uses `tqdm` and has no callback parameter. We MUST copy and modify it to add callback support for ComfyUI progress bars.

```python
import torch


def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Discrete Euler sampler for rectified flow, with optional callback.

    Modified from PrismAudio to add callback parameter for ComfyUI progress reporting.
    Original uses tqdm internally.

    Args:
        model: The diffusion model (DiTWrapper)
        x: Initial noise tensor [B, C, T]
        steps: Number of sampling steps
        sigma_max: Maximum sigma (default 1.0 for rectified flow)
        callback: Optional callable({"i": step, "x": current_x}) for progress
        **extra_args: Passed to model() — includes cross_attn_cond, add_cond,
            sync_cond, cfg_scale, batch_cfg, etc.
    """
    # Time grid from sigma_max (pure noise) down to 0 (clean latent)
    t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)
    for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])):
        dt = t_next - t_curr  # negative: integrating from noise toward data
        t_curr_tensor = t_curr * torch.ones(x.shape[0], dtype=x.dtype, device=x.device)
        # Euler step along the velocity predicted by the model
        x = x + dt * model(x, t_curr_tensor, **extra_args)
        if callback is not None:
            callback({"i": i, "x": x})
    return x


# Note: sample_rf() and sample() (v-diffusion) are NOT included.
# PrismAudio uses rectified_flow objective which only needs sample_discrete_euler.
# Including unused samplers with potentially wrong math is a maintenance hazard.
```

**Step 2: Extract inference/utils.py**

Keep:
- `set_audio_channels(audio, target_channels)`
- `prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device)`

**Step 3: Verify sampling import**

Run: `python -c "from prismaudio_core.inference.sampling import sample_discrete_euler; print('OK')"`
Expected: `OK`

**Step 4: Commit**

```bash
git add prismaudio_core/inference/
git commit -m "feat: extract prismaudio_core inference with callback-enabled sampling"
```

---

### Task 5: PrismAudioModelLoader Node

**Files:**
- Create: `nodes/model_loader.py`

**Step 1: Write the node**

Key design decisions:
- No hf_token widget (security risk — saved to workflow JSON). Uses env var / cached login only.
- Creates the model with the default config. Nothing duration-dependent is configured at load time.
- The model config's `sample_size: 397312` corresponds to a ~9s default. Other durations need no sequence-length changes: the DiT adapts to the latent length of the noise tensor the Sampler creates (see bug fixes 5, 6 and 17).
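The duration arithmetic behind the `sample_size` bullet can be checked with the constants from `nodes/utils.py` (`SAMPLE_RATE = 44100`, `DOWNSAMPLING_RATIO = 2048`). A small sketch; the helper names are hypothetical, and the rounding mirrors the `round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)` expression the Sampler uses:

```python
SAMPLE_RATE = 44100        # audio samples per second (from nodes/utils.py)
DOWNSAMPLING_RATIO = 2048  # audio samples per VAE latent frame (from nodes/utils.py)


def latent_length_for(duration_s: float) -> int:
    """Latent frames the Sampler requests for a given duration."""
    return round(SAMPLE_RATE * duration_s / DOWNSAMPLING_RATIO)


def duration_for_samples(num_samples: int) -> float:
    """Seconds of audio covered by `num_samples` samples."""
    return num_samples / SAMPLE_RATE


default_samples = 397312  # `sample_size` in prismaudio.json
print(f"{duration_for_samples(default_samples):.2f} s")  # 9.01 s
print(default_samples // DOWNSAMPLING_RATIO)             # 194
print(latent_length_for(10.0))                            # 215
```

Because of the rounding, and assuming the VAE decoder upsamples each latent frame to exactly 2048 samples, a 10 s request decodes to 215 × 2048 = 440,320 samples, slightly under 10 s.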
```python import os import json import torch import folder_paths import comfy.model_management as mm import comfy.utils from .utils import ( PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder, get_device, get_offload_device, determine_precision, determine_offload_strategy, soft_empty_cache, resolve_hf_token, ) # HuggingFace repo for auto-download HF_REPO_ID = "FunAudioLLM/PrismAudio" REQUIRED_FILES = { "diffusion": "prismaudio.ckpt", "vae": "vae.ckpt", "synchformer": "synchformer_state_dict.pth", } def _download_if_missing(filename, model_dir, hf_token=None): """Download a model file from HuggingFace if not present locally.""" filepath = os.path.join(model_dir, filename) if os.path.exists(filepath): return filepath from huggingface_hub import hf_hub_download print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...") try: downloaded = hf_hub_download( repo_id=HF_REPO_ID, filename=filename, local_dir=model_dir, token=hf_token or None, ) return downloaded except Exception as e: if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower(): raise RuntimeError( f"[PrismAudio] Model '{filename}' requires license acceptance. 
" f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, " f"then set HF_TOKEN env var or run: huggingface-cli login" ) from e raise class PrismAudioModelLoader: @classmethod def INPUT_TYPES(cls): register_model_folder() return { "required": { "precision": (["auto", "fp32", "fp16", "bf16"],), "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],), }, } RETURN_TYPES = ("PRISMAUDIO_MODEL",) RETURN_NAMES = ("model",) FUNCTION = "load_model" CATEGORY = PRISMAUDIO_CATEGORY def load_model(self, precision, offload_strategy): device = get_device() dtype = determine_precision(precision, device) strategy = determine_offload_strategy(offload_strategy) token = resolve_hf_token() model_dir = get_prismaudio_model_dir() # Auto-download missing files for key, filename in REQUIRED_FILES.items(): _download_if_missing(filename, model_dir, hf_token=token) # Load config config_path = os.path.join( os.path.dirname(os.path.dirname(__file__)), "prismaudio_core", "configs", "prismaudio.json" ) with open(config_path) as f: model_config = json.load(f) # Create model from config from prismaudio_core.factory import create_model_from_config model = create_model_from_config(model_config) # Load diffusion weights diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"]) diffusion_state = comfy.utils.load_torch_file(diffusion_path) # Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...} if "state_dict" in diffusion_state: diffusion_state = diffusion_state["state_dict"] model.load_state_dict(diffusion_state, strict=False) # Load VAE weights separately # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"]) vae_full_state = comfy.utils.load_torch_file(vae_path) # Strip "autoencoder." prefix from keys vae_state = {} prefix = "autoencoder." 
for k, v in vae_full_state.items(): if k.startswith(prefix): vae_state[k[len(prefix):]] = v else: vae_state[k] = v model.pretransform.load_state_dict(vae_state) # Apply precision: DiT + conditioners in user-selected dtype, # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16 model.model.to(dtype) # DiTWrapper model.conditioner.to(dtype) # MultiConditioner # model.pretransform stays in fp32 if strategy == "keep_in_vram": model = model.to(device) else: model = model.to(get_offload_device()) model.eval() return ({ "model": model, "dtype": dtype, "strategy": strategy, "config": model_config, "model_dir": model_dir, },) ``` **Step 2: Test that ComfyUI discovers the node** Run ComfyUI and check that "PrismAudio Model Loader" appears in the node list. **Step 3: Commit** ```bash git add nodes/model_loader.py git commit -m "feat: PrismAudioModelLoader node with auto-download and adaptive VRAM" ``` --- ### Task 6: PrismAudioFeatureLoader Node **Files:** - Create: `nodes/feature_loader.py` **Step 1: Write the node** ```python import os import numpy as np import torch from .utils import PRISMAUDIO_CATEGORY # Keys consumed by the conditioners (video_features, text_features, sync_features) # global_video_features and global_text_features are NOT consumed by any conditioner # in the prismaudio.json config — they are unused. 
REQUIRED_KEYS = [ "video_features", "text_features", "sync_features", ] class PrismAudioFeatureLoader: @classmethod def INPUT_TYPES(cls): return { "required": { "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}), }, } RETURN_TYPES = ("PRISMAUDIO_FEATURES",) RETURN_NAMES = ("features",) FUNCTION = "load_features" CATEGORY = PRISMAUDIO_CATEGORY def load_features(self, npz_path): if not os.path.exists(npz_path): raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}") data = np.load(npz_path, allow_pickle=True) features = {} for key in REQUIRED_KEYS: if key in data: features[key] = torch.from_numpy(data[key]).float() else: print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros") # Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None # Sync_MLP requires length divisible by 8 (segments of 8 frames) if key == "sync_features": features[key] = torch.zeros(8, 768) else: features[key] = torch.zeros(1, 1024) # Load duration if present if "duration" in data: features["duration"] = float(data["duration"]) return (features,) ``` **Step 2: Commit** ```bash git add nodes/feature_loader.py git commit -m "feat: PrismAudioFeatureLoader node for pre-computed .npz files" ``` --- ### Task 7: PrismAudioFeatureExtractor Node (Subprocess Bridge) **Files:** - Create: `nodes/feature_extractor.py` - Create: `scripts/extract_features.py` - Create: `scripts/environment.yml` **Step 1: Create the conda environment.yml** ```yaml name: prismaudio-extract channels: - conda-forge - defaults dependencies: - python=3.10 - pip - ffmpeg<7 - pip: - torch>=2.6.0 - torchaudio>=2.6.0 - torchvision>=0.21.0 - tensorflow-cpu==2.15.0 - jax - jaxlib - transformers>=4.52.3 - decord - einops>=0.7.0 - numpy - mediapy - git+https://github.com/google-deepmind/videoprism.git ``` **Step 2: Create scripts/extract_features.py** This is a standalone script that: 1. Takes `--video`, `--cot_text`, `--output` arguments 2. 
Loads VideoPrism, T5-Gemma, Synchformer 3. Extracts features from the video 4. Saves as `.npz` ```python #!/usr/bin/env python3 """ Standalone PrismAudio feature extraction script. Run in a separate conda env with JAX/TF installed. Usage: python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz Setup: conda env create -f environment.yml conda activate prismaudio-extract """ import argparse import os import sys import numpy as np import torch def main(): parser = argparse.ArgumentParser(description="PrismAudio feature extraction") parser.add_argument("--video", required=True, help="Path to input video") parser.add_argument("--cot_text", required=True, help="Chain-of-thought description") parser.add_argument("--output", required=True, help="Output .npz path") parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint") parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON") parser.add_argument("--clip_fps", type=float, default=4.0) parser.add_argument("--clip_size", type=int, default=288) parser.add_argument("--sync_fps", type=float, default=25.0) parser.add_argument("--sync_size", type=int, default=224) args = parser.parse_args() if not os.path.exists(args.video): print(f"Error: Video not found: {args.video}") sys.exit(1) # Import feature extraction utils (requires JAX/TF) from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils import torchvision.transforms as T from decord import VideoReader, cpu device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Initialize feature extractor feat_utils = FeaturesUtils( vae_config_path=args.vae_config, synchformer_ckpt=args.synchformer_ckpt, device=device, ) # Load and preprocess video vr = VideoReader(args.video, ctx=cpu(0)) fps = vr.get_avg_fps() total_frames = len(vr) duration = total_frames / fps # Extract CLIP frames (4fps, 288x288) clip_indices = [int(i * fps / args.clip_fps) for i in 
range(int(duration * args.clip_fps))] clip_indices = [min(i, total_frames - 1) for i in clip_indices] clip_frames = vr.get_batch(clip_indices).asnumpy() clip_transform = T.Compose([ T.ToPILImage(), T.Resize(args.clip_size), T.CenterCrop(args.clip_size), T.ToTensor(), T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ]) clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device) # Extract Sync frames (25fps, 224x224) sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] sync_indices = [min(i, total_frames - 1) for i in sync_indices] sync_frames = vr.get_batch(sync_indices).asnumpy() sync_transform = T.Compose([ T.ToPILImage(), T.Resize(args.sync_size), T.CenterCrop(args.sync_size), T.ToTensor(), T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ]) sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device) # Extract features print("[PrismAudio] Encoding text with T5-Gemma...") text_features = feat_utils.encode_t5_text([args.cot_text]) print("[PrismAudio] Encoding video with VideoPrism...") global_video_features, video_features, global_text_features = \ feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text]) print("[PrismAudio] Encoding video with Synchformer...") sync_features = feat_utils.encode_video_with_sync(sync_input) # Save as .npz np.savez( args.output, video_features=video_features.cpu().numpy(), global_video_features=global_video_features.cpu().numpy(), text_features=text_features.cpu().numpy(), global_text_features=global_text_features.cpu().numpy(), sync_features=sync_features.cpu().numpy(), caption_cot=args.cot_text, duration=duration, ) print(f"[PrismAudio] Features saved to {args.output}") if __name__ == "__main__": main() ``` **Step 3: Create the feature extractor node** ```python import os import hashlib import subprocess import tempfile import torch from .utils import PRISMAUDIO_CATEGORY from .feature_loader import 
PrismAudioFeatureLoader def _hash_inputs(video_tensor, cot_text): """Create a hash of the inputs for caching.""" h = hashlib.sha256() h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed h.update(cot_text.encode()) return h.hexdigest()[:16] def _save_video_tensor_to_mp4(video_tensor, output_path, fps=30): """Save ComfyUI IMAGE tensor [T,H,W,C] to MP4.""" import torchvision.io as tvio # ComfyUI IMAGE is [T,H,W,C] float32 [0,1] frames = (video_tensor * 255).to(torch.uint8) # torchvision write_video expects [T,H,W,C] uint8 tvio.write_video(output_path, frames, fps=fps) class PrismAudioFeatureExtractor: @classmethod def INPUT_TYPES(cls): return { "required": { "video": ("IMAGE",), "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}), }, "optional": { "python_env": ("STRING", {"default": "python", "tooltip": "Path to python binary with JAX/TF (e.g., /path/to/conda/envs/prismaudio-extract/bin/python)"}), "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. 
Empty = temp dir"}), "synchformer_ckpt": ("STRING", {"default": "", "tooltip": "Path to synchformer checkpoint (auto-resolved if empty)"}), }, } RETURN_TYPES = ("PRISMAUDIO_FEATURES",) RETURN_NAMES = ("features",) FUNCTION = "extract_features" CATEGORY = PRISMAUDIO_CATEGORY def extract_features(self, video, caption_cot, python_env="python", cache_dir="", synchformer_ckpt=""): # Determine cache directory if not cache_dir: cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features") os.makedirs(cache_dir, exist_ok=True) # Check cache cache_hash = _hash_inputs(video, caption_cot) cached_path = os.path.join(cache_dir, f"{cache_hash}.npz") if os.path.exists(cached_path): print(f"[PrismAudio] Using cached features: {cached_path}") loader = PrismAudioFeatureLoader() return loader.load_features(cached_path) # Save video to temp file with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp: tmp_video = tmp.name _save_video_tensor_to_mp4(video, tmp_video) # Build subprocess command script_path = os.path.join( os.path.dirname(os.path.dirname(__file__)), "scripts", "extract_features.py" ) cmd = [ python_env, script_path, "--video", tmp_video, "--cot_text", caption_cot, "--output", cached_path, ] if synchformer_ckpt: cmd.extend(["--synchformer_ckpt", synchformer_ckpt]) print(f"[PrismAudio] Extracting features via subprocess...") try: result = subprocess.run( cmd, capture_output=True, text=True, timeout=600, # 10 minute timeout ) if result.returncode != 0: raise RuntimeError( f"[PrismAudio] Feature extraction failed:\n{result.stderr}" ) print(result.stdout) finally: if os.path.exists(tmp_video): os.unlink(tmp_video) # Load the extracted features loader = PrismAudioFeatureLoader() return loader.load_features(cached_path) ``` **Step 4: Commit** ```bash git add nodes/feature_extractor.py scripts/extract_features.py scripts/environment.yml git commit -m "feat: PrismAudioFeatureExtractor node with subprocess bridge and conda env" ``` --- ### Task 8: 
PrismAudioSampler Node **Files:** - Create: `nodes/sampler.py` **Step 1: Write the sampler node** This is the core node. Key fixes from review: - Metadata is a TUPLE of dicts, not a flat dict - video_exist is torch.tensor, not Python bool - Empty features are zero tensors, not None - Peak normalization before clamp - Sequence lengths set on DiT config before sampling (matching predict.py approach) - No callback kwarg forwarded to model — callback is handled by our modified sample_discrete_euler ```python import torch import comfy.model_management as mm import comfy.utils from .utils import ( PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS, get_device, get_offload_device, soft_empty_cache, ) class PrismAudioSampler: @classmethod def INPUT_TYPES(cls): return { "required": { "model": ("PRISMAUDIO_MODEL",), "features": ("PRISMAUDIO_FEATURES",), "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds"}), "steps": ("INT", {"default": 24, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}), "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), }, } RETURN_TYPES = ("AUDIO",) RETURN_NAMES = ("audio",) FUNCTION = "generate" CATEGORY = PRISMAUDIO_CATEGORY def generate(self, model, features, duration, steps, cfg_scale, seed): device = get_device() dtype = model["dtype"] strategy = model["strategy"] diffusion = model["model"] # Compute latent dimensions latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO) # Note: no seq length config needed — the model adapts to input tensor shapes # dynamically via its transformer architecture. 
        # Determine if video features are present (not all zeros).
        # bool() avoids passing a tensor into torch.tensor() below.
        has_video = bool(
            features.get("video_features") is not None
            and features["video_features"].abs().sum() > 0
        )

        # Build metadata as a TUPLE of dicts (one per batch sample)
        # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
        sample_meta = {
            "video_features": features["video_features"].to(device, dtype=dtype),
            "text_features": features["text_features"].to(device, dtype=dtype),
            "sync_features": features["sync_features"].to(device, dtype=dtype),
            "video_exist": torch.tensor(has_video),
        }
        metadata = (sample_meta,)

        # Move model to device if offloaded
        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            # Run conditioning
            conditioning = diffusion.conditioner(metadata, device)

            # Handle missing video: substitute learned empty embeddings
            if not has_video:
                _substitute_empty_features(diffusion, conditioning, device, dtype)

            # Assemble conditioning inputs for the DiT
            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (MPS doesn't support torch.Generator)
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            # Sample with progress bar
            pbar = comfy.utils.ProgressBar(steps)
            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

            # Offload diffusion model and conditioner before VAE decode
            if strategy == "offload_to_cpu":
                diffusion.model.to(get_offload_device())
                diffusion.conditioner.to(get_offload_device())
                soft_empty_cache()
                diffusion.pretransform.to(device)

            # VAE decode in fp32 (snake activations overflow in fp16)
            with torch.amp.autocast(device_type=device.type, enabled=False):
                audio = diffusion.pretransform.decode(fakes.float())

        # Offload VAE
        if strategy == "offload_to_cpu":
            diffusion.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak normalize then clamp (matching reference: div by max abs before clamp)
        audio = audio.float()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)

        # Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int}
        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)


def _substitute_empty_features(diffusion, conditioning, device, dtype):
    """Replace sync conditioning with learned empty embedding when video is absent.

    Only substitutes sync_features, NOT video_features. The reference code
    (predict.py/app.py) checks for 'metaclip_features', which doesn't exist in
    the prismaudio.json config, so video substitution never runs. Cond_MLP with
    zero input + bias-free linear layers naturally produces near-zero output.

    The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim].
    """
    dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model

    # Only substitute sync_features (matching reference behavior for prismaudio config)
    if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning:
        empty = dit.empty_sync_feat.to(device, dtype=dtype)
        cond_tensor = conditioning['sync_features'][0]
        batch_size = cond_tensor.shape[0]
        empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1)
        conditioning['sync_features'][0] = empty_expanded
        conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device)
```

**Step 2: Verify the node registers**

Start ComfyUI and check that "PrismAudio Sampler" appears in the add-node menu.
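The sampler's noise tensor is sized in latent frames (`latent_length`), using the same `round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)` formula that `PrismAudioTextOnly.generate` uses, so the decoded waveform length is quantized rather than exactly `duration * SAMPLE_RATE`. A minimal sketch of that arithmetic follows, assuming the VAE decoder emits exactly `DOWNSAMPLING_RATIO` samples per latent frame. The helper names and the 44100 Hz / 2048 values are illustrative assumptions, not values confirmed from `prismaudio_core/configs/prismaudio.json`.

```python
def latent_frames(duration_s: float, sample_rate: int, downsampling_ratio: int) -> int:
    """Latent length for the noise tensor (same formula as the nodes)."""
    return round(sample_rate * duration_s / downsampling_ratio)


def expected_num_samples(duration_s: float, sample_rate: int, downsampling_ratio: int) -> int:
    """Decoded waveform length, ASSUMING the VAE emits exactly
    `downsampling_ratio` samples per latent frame."""
    return latent_frames(duration_s, sample_rate, downsampling_ratio) * downsampling_ratio


if __name__ == "__main__":
    # Illustrative values only; read the real ones from nodes/utils.py.
    SR, RATIO = 44100, 2048
    for seconds in (5.0, 10.0, 20.0):
        n = expected_num_samples(seconds, SR, RATIO)
        print(f"{seconds:>4}s -> {latent_frames(seconds, SR, RATIO)} latents, "
              f"{n} samples ({n / SR:.3f}s)")
```

Comparing the returned `waveform.shape[-1]` against `expected_num_samples` with the real constants is one way to make the variable-duration check in Task 10 concrete.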
**Step 3: Commit**

```bash
git add nodes/sampler.py
git commit -m "feat: PrismAudioSampler node with correct metadata format and peak normalization"
```

---

### Task 9: PrismAudioTextOnly Node

**Files:**
- Create: `nodes/text_only.py`

**Step 1: Write the text-only node**

Key fixes from review:
- Uses `AutoModelForSeq2SeqLM.get_encoder()`, not `AutoModel.encoder`
- No truncation (matching reference)
- Metadata is a tuple of dicts with `torch.tensor(False)` for `video_exist`
- Zero tensors for video/sync features, not None
- Peak normalization

```python
import torch
import comfy.model_management as mm
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY,
    SAMPLE_RATE,
    DOWNSAMPLING_RATIO,
    IO_CHANNELS,
    get_device,
    get_offload_device,
    soft_empty_cache,
    resolve_hf_token,
)
from .sampler import _substitute_empty_features


class PrismAudioTextOnly:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "text_prompt": ("STRING", {"default": "", "multiline": True,
                                           "tooltip": "Text description for audio generation"}),
                "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
                "steps": ("INT", {"default": 24, "min": 1, "max": 100}),
                "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]

        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Encode text with T5-Gemma
        text_features = _encode_text_t5(text_prompt, device, dtype)

        # Build metadata: tuple of one dict per sample
        # Use zero tensors for video/sync (not None: Cond_MLP crashes on None via pad_sequence)
        # Sync_MLP requires length divisible by 8 (segments of 8 frames), so minimum [8, 768]
        # sync_features is replaced with the learned empty embedding after conditioning;
        # zero video_features pass through Cond_MLP and yield near-zero output
        sample_meta = {
            "video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
            "text_features": text_features.to(device, dtype=dtype),
            "sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
            "video_exist": torch.tensor(False),
        }
        metadata = (sample_meta,)

        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            conditioning = diffusion.conditioner(metadata, device)

            # Substitute the learned empty embedding for sync features (video stays zero)
            _substitute_empty_features(diffusion, conditioning, device, dtype)

            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (MPS doesn't support torch.Generator)
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            pbar = comfy.utils.ProgressBar(steps)
            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

            if strategy == "offload_to_cpu":
                diffusion.model.to(get_offload_device())
                diffusion.conditioner.to(get_offload_device())
                soft_empty_cache()
                diffusion.pretransform.to(device)

            # VAE decode in fp32 (snake activations overflow in fp16)
            with torch.amp.autocast(device_type=device.type, enabled=False):
                audio = diffusion.pretransform.decode(fakes.float())

        if strategy == "offload_to_cpu":
            diffusion.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak normalize then clamp
        audio = audio.float()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)

        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)


# T5-Gemma encoder singleton
_t5_model = None
_t5_tokenizer = None


def _encode_text_t5(text, device, dtype):
    """Encode text using T5-Gemma.

    Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference
    FeaturesUtils.encode_t5_text() implementation.
    No truncation applied (matching reference behavior).
    """
    global _t5_model, _t5_tokenizer

    if _t5_model is None:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        model_id = "google/t5gemma-l-l-ul2-it"
        token = resolve_hf_token()
        print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
        _t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        _t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder()
        _t5_model.eval()

    # The encoder is parked on CPU between calls, so move it back every time
    _t5_model.to(device, dtype=dtype)

    tokens = _t5_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        outputs = _t5_model(**tokens)

    # Move T5 off GPU after encoding to save VRAM
    _t5_model.to("cpu")
    soft_empty_cache()

    return outputs.last_hidden_state.squeeze(0)  # [seq_len, dim]
```

**Step 2: Commit**

```bash
git add nodes/text_only.py
git commit -m "feat: PrismAudioTextOnly node with correct T5-Gemma encoding"
```

---

### Task 10: Integration Testing & Polish

**Files:**
- Modify: `nodes/__init__.py` (verify all imports work)
- Modify: `__init__.py` (verify top-level registration)

**Step 1: Verify all node imports resolve**

Run from ComfyUI's Python:

```bash
cd /path/to/ComfyUI
python -c "
import sys
sys.path.insert(0, 'custom_nodes/ComfyUI-PrismAudio')
from nodes import NODE_CLASS_MAPPINGS
print('Registered nodes:', list(NODE_CLASS_MAPPINGS.keys()))
"
```

Expected output:

```
Registered nodes: ['PrismAudioModelLoader', 'PrismAudioFeatureLoader', 'PrismAudioFeatureExtractor', 'PrismAudioSampler', 'PrismAudioTextOnly']
```

**Step 2: Fix any import errors iteratively**

Common issues:
- `prismaudio_core` internal imports may reference wrong module paths
- Missing model submodules in `prismaudio_core/models/`
- flash-attn fallback not properly guarded

**Step 3: Test model loading (requires GPU + model files)**

Run from the `ComfyUI-PrismAudio` directory, since the config path is relative:

```bash
python -c "
from prismaudio_core.factory import create_model_from_config
import json
with open('prismaudio_core/configs/prismaudio.json') as f:
    config = json.load(f)
model = create_model_from_config(config)
print('Model created, params:', sum(p.numel() for p in model.parameters()) / 1e6, 'M')
"
```

Expected: `Model created, params: ~518 M`

**Step 4: End-to-end test with pre-computed features**

If you have a `.npz` feature file from the PrismAudio repo's demo data, test the full pipeline in ComfyUI:

1. PrismAudioModelLoader -> PrismAudioFeatureLoader -> PrismAudioSampler -> Preview Audio node

**Step 5: Verify variable duration handling**

Test with multiple durations (5s, 10s, 20s) to ensure the model adapts to different input shapes and produces audio of the expected length.

**Step 6: Commit**

```bash
git add -A
git commit -m "feat: integration fixes and verification"
```

---

### Task 11: README

**Files:**
- Create: `README.md`

**Step 1: Write README covering:**
- What PrismAudio is (brief, link to paper)
- Installation (clone, pip install requirements, optional extraction env setup)
- Node descriptions with input/output tables
- Example workflows (quality path with FeatureExtractor, quick path with FeatureLoader, text-only)
- HuggingFace authentication (2 methods: `HF_TOKEN` env var, `huggingface-cli login`)
- Note: hf_token is NOT a node widget for security reasons
- Which models may be gated (T5-Gemma, potentially Stable Audio VAE)
- Model file sizes: diffusion ~2.7GB, VAE ~2.5GB, synchformer ~950MB
- Extraction env setup via conda environment.yml
- Troubleshooting (VRAM, flash-attn optional, gated models)
- Credits and license

**Step 2: Commit**

```bash
git add README.md
git commit -m "docs: README with installation and usage instructions"
```

---

## Dependency Graph

```
Task 1 (scaffolding)
├── Task 2 (core config + factory) ──┐
│   └── Task 3 (core models) ────────┤
│       └── Task 4 (core sampling) ──┤
│                                    ├── Task 5 (ModelLoader node)
│                                    ├── Task 6 (FeatureLoader node)
│                                    ├── Task 7 (FeatureExtractor node)
│                                    ├── Task 8 (Sampler node)
│                                    └── Task 9 (TextOnly node)
└────────────────────────────────────── Task 10 (Integration)
                                        └── Task 11 (README)
```

Tasks 5-9 can be parallelized after Task 4 is complete. Task 3 is the heaviest: it involves extracting and adapting ~10 model files.
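The parallelization claim can also be checked mechanically. The sketch below encodes the edges as read from the diagram, with one stated assumption: Task 10 waits on all of Tasks 5-9, which the diagram implies but does not draw as explicit edges. It uses the standard-library `graphlib.TopologicalSorter` (Python 3.9+) to group tasks into waves whose members can run in parallel; `TASK_DEPS` and `execution_waves` are hypothetical names for illustration.

```python
from graphlib import TopologicalSorter

# Edges read from the dependency graph: task -> set of tasks it waits on.
# Assumption: Task 10 (integration) waits on every node task (5-9).
TASK_DEPS = {
    2: {1},
    3: {2},
    4: {3},
    **{t: {4} for t in range(5, 10)},
    10: set(range(5, 10)),
    11: {10},
}


def execution_waves(deps):
    """Group tasks into waves; every task in a wave can run in parallel."""
    sorter = TopologicalSorter(deps)
    sorter.prepare()  # raises graphlib.CycleError if the graph has a cycle
    waves = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())  # all tasks whose deps are done
        waves.append(ready)
        sorter.done(*ready)
    return waves


if __name__ == "__main__":
    for i, wave in enumerate(execution_waves(TASK_DEPS), start=1):
        print(f"wave {i}: tasks {wave}")
```

Tasks 5-9 land in a single wave, which matches the note that they can be parallelized once Task 4 is done.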