Files
ComfyUI-SelVA/docs/plans/2026-03-27-comfyui-prismaudio-design.md
T

195 lines
7.7 KiB
Markdown

# ComfyUI-PrismAudio Design Document
**Date:** 2026-03-27
**Status:** Approved
## Overview
ComfyUI nodes for PrismAudio (ICLR 2026) — video-to-audio and text-to-audio generation. PrismAudio uses decomposed Chain-of-Thought reasoning across 4 dimensions (Semantic, Temporal, Aesthetic, Spatial) with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE.
## Architecture
**Approach C: Selective Code Extraction** — Extract only inference-critical code from PrismAudio into a self-contained `prismaudio_core/` module. No JAX/TensorFlow in the ComfyUI environment. Feature extraction via separate isolated environment.
## Project Structure
```
ComfyUI-PrismAudio/
├── __init__.py # Node registration
├── nodes/
│ ├── __init__.py
│ ├── model_loader.py # PrismAudioModelLoader
│ ├── feature_loader.py # PrismAudioFeatureLoader (loads .npz)
│ ├── feature_extractor.py # PrismAudioFeatureExtractor (subprocess bridge)
│ ├── sampler.py # PrismAudioSampler
│ ├── text_only.py # PrismAudioTextOnly
│ └── utils.py # Shared helpers
├── prismaudio_core/ # Extracted inference code from PrismAudio
│ ├── __init__.py
│ ├── configs/
│ │ └── prismaudio.json
│ ├── models/ # DiT, conditioners, autoencoders, etc.
│ ├── inference/ # sampling.py, generation.py
│ └── factory.py # create_model_from_config
├── scripts/
│ ├── extract_features.py # Standalone VideoPrism feature extraction
│ └── environment.yml # Conda env for extraction (JAX + TF)
├── requirements.txt # PyTorch-only deps (no JAX/TF)
└── README.md
```
## Nodes
### PrismAudioModelLoader
Loads the diffusion model + VAE. Auto-downloads from HuggingFace if weights not found locally.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| precision | COMBO | [auto, fp32, fp16, bf16] — auto detects GPU capability |
| offload_strategy | COMBO | [auto, keep_in_vram, offload_to_cpu] |
| *(no hf_token widget — security risk, would be saved to workflow JSON)* | | |
| **Output** | | |
| model | PRISMAUDIO_MODEL | Dict containing diffusion model + VAE + config |
**Token resolution order** (no widget — env/CLI only for security):
1. `HF_TOKEN` environment variable
2. `huggingface-cli login` cached token
3. None — fails on gated models with clear error message linking to license page
**Auto-download:** Uses `huggingface_hub.hf_hub_download()` from `FunAudioLLM/PrismAudio`. Models stored in `ComfyUI/models/prismaudio/`. Users can also place files manually.
### PrismAudioFeatureLoader
Loads pre-computed `.npz` feature files for maximum quality video-to-audio.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| npz_path | STRING | Path to .npz file |
| **Output** | | |
| features | PRISMAUDIO_FEATURES | Dict with video_features, global_video_features, text_features, global_text_features, sync_features |
### PrismAudioFeatureExtractor
Subprocess bridge — extracts features from video using VideoPrism in an isolated environment.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| video | IMAGE | ComfyUI video frames tensor |
| caption_cot | STRING | CoT description text |
| python_env | STRING | Path to python binary with JAX/TF (default: "python") |
| output_dir | STRING | Cache directory for .npz files (default: temp dir) |
| **Output** | | |
| features | PRISMAUDIO_FEATURES | Same format as FeatureLoader output |
**Caching:** Hashes video + text to avoid re-extraction on repeated runs.
### PrismAudioSampler
Main generation node — takes model + features, produces audio.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| model | PRISMAUDIO_MODEL | From ModelLoader |
| features | PRISMAUDIO_FEATURES | From FeatureLoader or FeatureExtractor |
| cot_description | STRING | Multiline CoT text |
| duration | FLOAT | 1.0-30.0, defaults to video length |
| steps | INT | 1-100, default 24 |
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
| seed | INT | Controls noise generation |
| **Output** | | |
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
**Pipeline:**
1. Encode CoT text via T5-Gemma -> text_features
2. Assemble conditioning (cross_attn_cond, add_cond, sync_cond)
3. Compute latent_seq_len = round(44100 / 2048 * duration)
4. Generate noise [1, 64, latent_seq_len] from seed
5. Discrete Euler sampling (rectified flow) with CFG
6. VAE decode -> stereo waveform at 44100 Hz
7. Normalize to [-1, 1], return as AUDIO
### PrismAudioTextOnly
Text-to-audio without video input.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| model | PRISMAUDIO_MODEL | From ModelLoader |
| text_prompt | STRING | Text description |
| duration | FLOAT | 1.0-30.0 |
| steps | INT | 1-100, default 24 |
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
| seed | INT | Controls noise generation |
| **Output** | | |
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
Uses empty tensors for video/sync features, T5-Gemma encodes the text prompt.
## VRAM Management
Adaptive strategy using `comfy.model_management`:
| Available VRAM | Behavior |
|---|---|
| 24GB+ | Keep diffusion + VAE in VRAM |
| 12-24GB | Sequential offload between stages |
| 8-12GB | Aggressive offload, one component on GPU at a time, fp16 forced |
| <8GB | Warn user, attempt with aggressive offload + fp16 |
Key APIs: `mm.get_torch_device()`, `mm.get_free_memory()`, `mm.soft_empty_cache()`, `mm.unet_offload_device()`
## Feature Extraction Paths
### Path 1: Pre-computed .npz (FeatureLoader)
User runs `scripts/extract_features.py` externally in the extraction conda env. Loads result into ComfyUI. Original VideoPrism quality, zero ComfyUI env risk.
### Path 2: Subprocess bridge (FeatureExtractor)
Node calls extraction script as subprocess using a user-specified Python binary. Seamless in-ComfyUI experience, JAX runs isolated. Caches results by content hash.
### Path 3: Text-only (TextOnly node)
No video features needed. T5-Gemma text encoding only (PyTorch-native).
## Dependencies
### ComfyUI environment (`requirements.txt`)
```
einops>=0.7.0
safetensors
huggingface_hub
transformers>=4.52.3
k-diffusion>=0.1.1
```
flash-attn: Optional, detected at runtime. Falls back to `torch.nn.functional.scaled_dot_product_attention`.
### Extraction environment (`scripts/environment.yml`)
Separate conda environment with JAX, tensorflow-cpu==2.15.0, VideoPrism, Synchformer, decord. Provided as ready-made conda env file for one-command setup.
## Model Files
Stored in `ComfyUI/models/prismaudio/`:
| File | Size | Source |
|------|------|--------|
| prismaudio.ckpt | ~2GB | FunAudioLLM/PrismAudio |
| vae.ckpt | ~2.5GB | FunAudioLLM/PrismAudio |
| synchformer_state_dict.pth | ~950MB | FunAudioLLM/PrismAudio |
T5-Gemma (`google/t5gemma-l-l-ul2-it`) cached in standard HuggingFace cache.
Registered via: `folder_paths.add_model_folder_path("prismaudio", ...)`
## Design Decisions
- **Composable**: Standard AUDIO output, CoT as plain STRING input. No reinventing save/preview/mux nodes.
- **No JAX/TF in ComfyUI env**: All JAX-dependent code isolated in extraction script/env.
- **LLM-agnostic CoT**: Users bring their own CoT generation via existing LLM nodes — better models available than bundled Qwen2.5-VL.
- **HF token via env/CLI only**: No widget (ComfyUI saves all STRING values to workflow JSON). Uses `HF_TOKEN` env var or `huggingface-cli login`.
- **flash-attn optional**: Avoids installation headaches, uses PyTorch SDPA as fallback.