From c9364c4ec22ada4bdba676fbb825e6fe16d38f30 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 27 Mar 2026 16:57:15 +0100 Subject: [PATCH] docs: initial design and implementation plan --- .../2026-03-27-comfyui-prismaudio-design.md | 194 +++ ...03-27-comfyui-prismaudio-implementation.md | 1372 +++++++++++++++++ 2 files changed, 1566 insertions(+) create mode 100644 docs/plans/2026-03-27-comfyui-prismaudio-design.md create mode 100644 docs/plans/2026-03-27-comfyui-prismaudio-implementation.md diff --git a/docs/plans/2026-03-27-comfyui-prismaudio-design.md b/docs/plans/2026-03-27-comfyui-prismaudio-design.md new file mode 100644 index 0000000..44ed361 --- /dev/null +++ b/docs/plans/2026-03-27-comfyui-prismaudio-design.md @@ -0,0 +1,194 @@ +# ComfyUI-PrismAudio Design Document + +**Date:** 2026-03-27 +**Status:** Approved + +## Overview + +ComfyUI nodes for PrismAudio (ICLR 2026) — video-to-audio and text-to-audio generation. PrismAudio uses decomposed Chain-of-Thought reasoning across 4 dimensions (Semantic, Temporal, Aesthetic, Spatial) with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE. + +## Architecture + +**Approach C: Selective Code Extraction** — Extract only inference-critical code from PrismAudio into a self-contained `prismaudio_core/` module. No JAX/TensorFlow in the ComfyUI environment. Feature extraction via separate isolated environment. + +## Project Structure + +``` +ComfyUI-PrismAudio/ +├── __init__.py # Node registration +├── nodes/ +│ ├── __init__.py +│ ├── model_loader.py # PrismAudioModelLoader +│ ├── feature_loader.py # PrismAudioFeatureLoader (loads .npz) +│ ├── feature_extractor.py # PrismAudioFeatureExtractor (subprocess bridge) +│ ├── sampler.py # PrismAudioSampler +│ ├── text_only.py # PrismAudioTextOnly +│ └── utils.py # Shared helpers +├── prismaudio_core/ # Extracted inference code from PrismAudio +│ ├── __init__.py +│ ├── configs/ +│ │ └── prismaudio.json +│ ├── models/ # DiT, conditioners, autoencoders, etc. +│ ├── inference/ # sampling.py, generation.py +│ └── factory.py # create_model_from_config +├── scripts/ +│ ├── extract_features.py # Standalone VideoPrism feature extraction +│ └── environment.yml # Conda env for extraction (JAX + TF) +├── requirements.txt # PyTorch-only deps (no JAX/TF) +└── README.md +``` + +## Nodes + +### PrismAudioModelLoader + +Loads the diffusion model + VAE. Auto-downloads from HuggingFace if weights not found locally. + +| Field | Type | Details | +|-------|------|---------| +| **Inputs** | | | +| precision | COMBO | [auto, fp32, fp16, bf16] — auto detects GPU capability | +| offload_strategy | COMBO | [auto, keep_in_vram, offload_to_cpu] | +| *(no hf_token widget — security risk, would be saved to workflow JSON)* | | | +| **Output** | | | +| model | PRISMAUDIO_MODEL | Dict containing diffusion model + VAE + config | + +**Token resolution order** (no widget — env/CLI only for security): +1. `HF_TOKEN` environment variable +2. `huggingface-cli login` cached token +3. None — fails on gated models with clear error message linking to license page + +**Auto-download:** Uses `huggingface_hub.hf_hub_download()` from `FunAudioLLM/PrismAudio`. Models stored in `ComfyUI/models/prismaudio/`. Users can also place files manually. + +### PrismAudioFeatureLoader + +Loads pre-computed `.npz` feature files for maximum quality video-to-audio. + +| Field | Type | Details | +|-------|------|---------| +| **Inputs** | | | +| npz_path | STRING | Path to .npz file | +| **Output** | | | +| features | PRISMAUDIO_FEATURES | Dict with video_features, global_video_features, text_features, global_text_features, sync_features | + +### PrismAudioFeatureExtractor + +Subprocess bridge — extracts features from video using VideoPrism in an isolated environment. + +| Field | Type | Details | +|-------|------|---------| +| **Inputs** | | | +| video | IMAGE | ComfyUI video frames tensor | +| caption_cot | STRING | CoT description text | +| python_env | STRING | Path to python binary with JAX/TF (default: "python") | +| output_dir | STRING | Cache directory for .npz files (default: temp dir) | +| **Output** | | | +| features | PRISMAUDIO_FEATURES | Same format as FeatureLoader output | + +**Caching:** Hashes video + text to avoid re-extraction on repeated runs. + +### PrismAudioSampler + +Main generation node — takes model + features, produces audio. + +| Field | Type | Details | +|-------|------|---------| +| **Inputs** | | | +| model | PRISMAUDIO_MODEL | From ModelLoader | +| features | PRISMAUDIO_FEATURES | From FeatureLoader or FeatureExtractor | +| cot_description | STRING | Multiline CoT text | +| duration | FLOAT | 1.0-30.0, defaults to video length | +| steps | INT | 1-100, default 24 | +| cfg_scale | FLOAT | 1.0-20.0, default 5.0 | +| seed | INT | Controls noise generation | +| **Output** | | | +| audio | AUDIO | {waveform: tensor, sample_rate: 44100} | + +**Pipeline:** +1. Encode CoT text via T5-Gemma -> text_features +2. Assemble conditioning (cross_attn_cond, add_cond, sync_cond) +3. Compute latent_seq_len = round(44100 / 2048 * duration) +4. Generate noise [1, 64, latent_seq_len] from seed +5. Discrete Euler sampling (rectified flow) with CFG +6. VAE decode -> stereo waveform at 44100 Hz +7. Normalize to [-1, 1], return as AUDIO + +### PrismAudioTextOnly + +Text-to-audio without video input. + +| Field | Type | Details | +|-------|------|---------| +| **Inputs** | | | +| model | PRISMAUDIO_MODEL | From ModelLoader | +| text_prompt | STRING | Text description | +| duration | FLOAT | 1.0-30.0 | +| steps | INT | 1-100, default 24 | +| cfg_scale | FLOAT | 1.0-20.0, default 5.0 | +| seed | INT | Controls noise generation | +| **Output** | | | +| audio | AUDIO | {waveform: tensor, sample_rate: 44100} | + +Uses empty tensors for video/sync features, T5-Gemma encodes the text prompt. + +## VRAM Management + +Adaptive strategy using `comfy.model_management`: + +| Available VRAM | Behavior | +|---|---| +| 24GB+ | Keep diffusion + VAE in VRAM | +| 12-24GB | Sequential offload between stages | +| 8-12GB | Aggressive offload, one component on GPU at a time, fp16 forced | +| <8GB | Warn user, attempt with aggressive offload + fp16 | + +Key APIs: `mm.get_torch_device()`, `mm.get_free_memory()`, `mm.soft_empty_cache()`, `mm.unet_offload_device()` + +## Feature Extraction Paths + +### Path 1: Pre-computed .npz (FeatureLoader) +User runs `scripts/extract_features.py` externally in the extraction conda env. Loads result into ComfyUI. Original VideoPrism quality, zero ComfyUI env risk. + +### Path 2: Subprocess bridge (FeatureExtractor) +Node calls extraction script as subprocess using a user-specified Python binary. Seamless in-ComfyUI experience, JAX runs isolated. Caches results by content hash. + +### Path 3: Text-only (TextOnly node) +No video features needed. T5-Gemma text encoding only (PyTorch-native). + +## Dependencies + +### ComfyUI environment (`requirements.txt`) +``` +einops>=0.7.0 +safetensors +huggingface_hub +transformers>=4.52.3 +k-diffusion>=0.1.1 +``` + +flash-attn: Optional, detected at runtime. Falls back to `torch.nn.functional.scaled_dot_product_attention`. + +### Extraction environment (`scripts/environment.yml`) +Separate conda environment with JAX, tensorflow-cpu==2.15.0, VideoPrism, Synchformer, decord. Provided as ready-made conda env file for one-command setup. + +## Model Files + +Stored in `ComfyUI/models/prismaudio/`: + +| File | Size | Source | +|------|------|--------| +| prismaudio.ckpt | ~2GB | FunAudioLLM/PrismAudio | +| vae.ckpt | ~2.5GB | FunAudioLLM/PrismAudio | +| synchformer_state_dict.pth | ~950MB | FunAudioLLM/PrismAudio | + +T5-Gemma (`google/t5gemma-l-l-ul2-it`) cached in standard HuggingFace cache. + +Registered via: `folder_paths.add_model_folder_path("prismaudio", ...)` + +## Design Decisions + +- **Composable**: Standard AUDIO output, CoT as plain STRING input. No reinventing save/preview/mux nodes. +- **No JAX/TF in ComfyUI env**: All JAX-dependent code isolated in extraction script/env. +- **LLM-agnostic CoT**: Users bring their own CoT generation via existing LLM nodes — better models available than bundled Qwen2.5-VL. +- **HF token via env/CLI only**: No widget (ComfyUI saves all STRING values to workflow JSON). Uses `HF_TOKEN` env var or `huggingface-cli login`. +- **flash-attn optional**: Avoids installation headaches, uses PyTorch SDPA as fallback. diff --git a/docs/plans/2026-03-27-comfyui-prismaudio-implementation.md b/docs/plans/2026-03-27-comfyui-prismaudio-implementation.md new file mode 100644 index 0000000..c1e5daf --- /dev/null +++ b/docs/plans/2026-03-27-comfyui-prismaudio-implementation.md @@ -0,0 +1,1372 @@ +# ComfyUI-PrismAudio Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Build ComfyUI custom nodes for PrismAudio video-to-audio and text-to-audio generation with adaptive VRAM management and isolated feature extraction. + +**Architecture:** Selective code extraction from PrismAudio `prismaudio` branch into `prismaudio_core/` module. 5 ComfyUI nodes (ModelLoader, FeatureLoader, FeatureExtractor, Sampler, TextOnly). Feature extraction via subprocess bridge to isolated JAX/TF environment. Auto-download from HuggingFace with gated model support. + +**Tech Stack:** PyTorch, ComfyUI APIs (folder_paths, comfy.model_management, comfy.utils), HuggingFace Hub, transformers (T5-Gemma), einops, k-diffusion, safetensors + +--- + +## Bug Fixes Applied (from review) + +This plan incorporates fixes for all 14 bugs identified during review: + +1. **sample_discrete_euler callback**: Copy function into prismaudio_core, add callback param to the sampling loop +2. **Metadata format**: Return `(dict,)` tuple, not flat dict — matches `MultiConditioner.forward(batch_metadata: List[Dict])` +3. **video_exist**: Use `torch.tensor(True/False)`, not Python bool +4. **None features**: Use zero tensors of correct shape, never None — `pad_sequence(None)` crashes +5. **update_seq_lengths removed**: Does not exist in source. Model adapts to input shapes dynamically — no seq length config needed +6. **Sequence length config**: Not needed — model handles variable lengths natively via input tensor shapes +7. **T5-Gemma class**: Use `AutoModelForSeq2SeqLM.get_encoder()`, not `AutoModel.encoder` +8. **Peak normalization**: Add `audio / audio.abs().max().clamp(min=1e-8)` before clamp +9. **Empty feature substitution**: Match reference approach — substitute on raw conditioning output with correct shapes +10. **hf_token security**: Remove STRING widget entirely. Rely on env var / `huggingface-cli login` only. Document in README +11. **Synchformer size**: Corrected to ~950MB in docs +12. **T5 truncation**: Match reference — `truncation=False`, no max_length +13. **Remove global_video/text_features from metadata**: Not consumed by any conditioner +14. **Add tqdm to requirements** + +### Bug Fixes Applied (from second review) + +15. **Sync_MLP zero-tensor crash**: Sync zero-tensor fallback must be `[8, 768]` not `[1, 768]` — Sync_MLP does `length // 8` which gives 0 for length=1, causing `F.interpolate` on empty tensor +16. **sample_discrete_euler undefined `i`**: Loop needs `enumerate()` — `for i, (t_curr, t_prev) in enumerate(zip(...))` +17. **_update_seq_lengths removed entirely**: Was a no-op (attributes don't exist on DiT). Model handles variable lengths natively — function deleted +18. **cot_description removed from Sampler**: Was dead code — features already contain pre-computed text_features +19. **Conditioner VRAM leak**: Add `diffusion.conditioner.to(get_offload_device())` after generation in offload path +20. **VAE size corrected**: ~2.52GB, not ~300MB + +### Bug Fixes Applied (from third review) + +21. **Remove video_features substitution**: `_substitute_empty_features` should only substitute sync_features. Reference code checks for `metaclip_features` (wrong key for prismaudio config), so video substitution never runs. Cond_MLP with zero input + bias-free linears naturally produces near-zero output +22. **Remove dead `sample()` and `sample_rf`**: Wrong noise schedule (linear vs cosine), never called for rectified_flow. Only keep `sample_discrete_euler` +23. **VAE decode in fp32**: Keep pretransform in fp32 even when rest of model is fp16/bf16 — snake activations overflow in fp16 +24. **Lazy imports in nodes/__init__.py**: Use try/except to allow incremental development +25. **MPS Generator guard**: `torch.Generator(device="cpu")` on Apple Silicon, move noise to device after +26. **Use comfy.utils.load_torch_file for VAE**: Consistent with diffusion loading, handles PyTorch 2.6+ weights_only default +27. **Task 10 stale reference**: Remove mention of `_update_seq_lengths` + +### Bug Fixes Applied (from fourth review) + +28. **TextOnly missing MPS guard**: Fix-on-fix regression — MPS Generator guard was applied to Sampler but not TextOnly +29. **TextOnly noise dtype**: Was passing dtype to torch.randn directly (fp16 noise), now generates fp32 then converts (matching Sampler) +30. **Sync substitution seq length**: Low-severity divergence from reference, accepted (DiT handles variable-length sync_cond) + +--- + +### Task 1: Project Scaffolding + +**Files:** +- Create: `__init__.py` +- Create: `nodes/__init__.py` +- Create: `nodes/utils.py` +- Create: `requirements.txt` + +**Step 1: Create requirements.txt** + +``` +einops>=0.7.0 +safetensors +huggingface_hub +transformers>=4.52.3 +k-diffusion>=0.1.1 +alias-free-torch +descript-audio-codec +tqdm +``` + +**Step 2: Create nodes/utils.py with shared helpers** + +```python +import os +import torch +import folder_paths +import comfy.model_management as mm + +PRISMAUDIO_CATEGORY = "PrismAudio" +SAMPLE_RATE = 44100 +DOWNSAMPLING_RATIO = 2048 +IO_CHANNELS = 64 + +def get_prismaudio_model_dir(): + """Get or create the prismaudio model directory.""" + model_dir = os.path.join(folder_paths.models_dir, "prismaudio") + os.makedirs(model_dir, exist_ok=True) + return model_dir + +def register_model_folder(): + """Register prismaudio model folder with ComfyUI.""" + model_dir = get_prismaudio_model_dir() + folder_paths.add_model_folder_path("prismaudio", model_dir) + +def get_device(): + return mm.get_torch_device() + +def get_offload_device(): + return mm.unet_offload_device() + +def get_free_memory(device=None): + if device is None: + device = get_device() + return mm.get_free_memory(device) + +def soft_empty_cache(): + mm.soft_empty_cache() + +def determine_precision(preference, device): + """Determine the best precision for the given device.""" + if preference != "auto": + return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference] + if device.type == "cpu": + return torch.float32 + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + return torch.bfloat16 + return torch.float16 + +def determine_offload_strategy(preference): + """Determine offload strategy based on available VRAM.""" + if preference != "auto": + return preference + free_mem = get_free_memory() + gb = free_mem / (1024 ** 3) + if gb >= 24: + return "keep_in_vram" + else: + return "offload_to_cpu" + +def try_import_flash_attn(): + """Try to import flash attention, return None if unavailable.""" + try: + import flash_attn + return flash_attn + except ImportError: + return None + +def resolve_hf_token(): + """Resolve HF token from env var or cached login. No widget — security risk.""" + env_token = os.environ.get("HF_TOKEN") + if env_token: + return env_token + # huggingface_hub will use cached token automatically if None is passed + return None +``` + +**Step 3: Create nodes/__init__.py** + +```python +NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} + +# Lazy imports — allows incremental development (nodes can be added one at a time) +_NODES = { + "PrismAudioModelLoader": (".model_loader", "PrismAudioModelLoader", "PrismAudio Model Loader"), + "PrismAudioFeatureLoader": (".feature_loader", "PrismAudioFeatureLoader", "PrismAudio Feature Loader"), + "PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"), + "PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"), + "PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"), +} + +for key, (module_path, class_name, display_name) in _NODES.items(): + try: + import importlib + mod = importlib.import_module(module_path, package=__name__) + NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name) + NODE_DISPLAY_NAME_MAPPINGS[key] = display_name + except (ImportError, AttributeError) as e: + print(f"[PrismAudio] Skipping {key}: {e}") +``` + +**Step 4: Create top-level __init__.py** + +```python +""" +ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026). +""" +from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] +``` + +**Step 5: Commit** + +```bash +git add __init__.py nodes/__init__.py nodes/utils.py requirements.txt +git commit -m "feat: project scaffolding with shared utils and node registration" +``` + +--- + +### Task 2: Extract prismaudio_core — Model Config + Factory + +**Files:** +- Create: `prismaudio_core/__init__.py` +- Create: `prismaudio_core/configs/prismaudio.json` (copy from PrismAudio repo) +- Create: `prismaudio_core/factory.py` (adapted from `PrismAudio/models/factory.py`) + +**Step 1: Create prismaudio_core/__init__.py** + +```python +""" +PrismAudio core inference modules. +Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch). +Only inference-critical code — no training, no JAX/TF dependencies. +""" +``` + +**Step 2: Copy prismaudio.json config** + +Fetch from `https://raw.githubusercontent.com/FunAudioLLM/ThinkSound/prismaudio/PrismAudio/configs/model_configs/prismaudio.json` and save to `prismaudio_core/configs/prismaudio.json`. This is a JSON config file with no code — copy verbatim. + +**Step 3: Create factory.py** + +Extract from `PrismAudio/models/factory.py`. Keep only these functions (remove training-related code): +- `create_model_from_config(model_config)` — entry point +- `create_diffusion_cond_from_config(config)` — creates the full model +- `create_pretransform_from_config(pretransform_config, sample_rate)` — VAE +- `create_autoencoder_from_config(config)` — AudioAutoencoder +- `create_bottleneck_from_config(config)` — VAEBottleneck +- `create_multi_conditioner_from_conditioning_config(config)` — conditioners + +All imports should reference `prismaudio_core.models.*` instead of `PrismAudio.models.*`. + +**Step 4: Commit** + +```bash +git add prismaudio_core/ +git commit -m "feat: extract prismaudio_core config and model factory" +``` + +--- + +### Task 3: Extract prismaudio_core — Model Modules + +**Files:** +- Create: `prismaudio_core/models/__init__.py` +- Create: `prismaudio_core/models/dit.py` (from `PrismAudio/models/dit.py`) +- Create: `prismaudio_core/models/diffusion.py` (from `PrismAudio/models/diffusion.py`) +- Create: `prismaudio_core/models/conditioners.py` (from `PrismAudio/models/conditioners.py`) +- Create: `prismaudio_core/models/autoencoders.py` (from `PrismAudio/models/autoencoders.py`) +- Create: `prismaudio_core/models/pretransforms.py` (from `PrismAudio/models/pretransforms.py`) +- Create: `prismaudio_core/models/blocks.py` (from `PrismAudio/models/blocks.py`) +- Create: `prismaudio_core/models/utils.py` (from `PrismAudio/models/utils.py`) +- Create: `prismaudio_core/models/bottleneck.py` (from `PrismAudio/models/bottleneck.py`) +- Create: `prismaudio_core/models/transformer.py` (from `PrismAudio/models/transformer.py`) +- Create: `prismaudio_core/models/local_attention.py` (if used by transformer) + +**Step 1: Extract model files** + +For each file, extract from the PrismAudio repo. Key modifications: +- Change all internal imports from `PrismAudio.models.*` to `prismaudio_core.models.*` +- Remove training-only code (loss functions, training step methods, gradient checkpointing setup) +- Keep all inference paths intact + +**Critical classes to preserve:** + +From `dit.py`: +- `DiffusionTransformer` — full class with `forward()`, CFG logic, conditioning assembly +- `FourierFeatures` — timestep embedding +- Keep `empty_clip_feat` and `empty_sync_feat` learned parameters (nn.Parameter, zeros) + +From `diffusion.py`: +- `ConditionedDiffusionModelWrapper` — with `get_conditioning_inputs()` and routing logic +- `DiTWrapper` — thin wrapper that passes all kwargs through +- `create_diffusion_cond_from_config()` — factory function + +From `conditioners.py`: +- `Cond_MLP` (type `"cond_mlp"`) — for video_features and text_features. Uses `pad_sequence`, 2-layer MLP, returns `[embeddings, ones_mask]`. During eval with batch<16, doubles batch with null embed for CFG +- `Sync_MLP` (type `"sync_mlp"`) — for sync_features with learnable `sync_pos_emb` of shape (1,1,8,dim), reshapes into segments of 8, interpolates to target length +- `MultiConditioner` — iterates over `batch_metadata: List[Dict]`, collects per-sample inputs, calls each conditioner. Returns dict of `{key: (tensor, mask)}` +- `create_multi_conditioner_from_conditioning_config()` — factory + +From `autoencoders.py`: +- `AudioAutoencoder` — with `encode_audio()` and `decode_audio()` +- `OobleckEncoder`, `OobleckDecoder` — with ResidualUnit, snake activation +- Dependencies: `alias_free_torch` (SnakeBeta), `dac.nn` (WNConv1d, WNConvTranspose1d) + +From `pretransforms.py`: +- `AutoencoderPretransform` — wraps AudioAutoencoder, `encode()` and `decode()` methods + +From `bottleneck.py`: +- `VAEBottleneck` — reparameterization trick (split mean/logvar, sample) + +From `blocks.py`: +- Any shared blocks used by the above (attention blocks, FeedForward, etc.) + +From `transformer.py`: +- `ContinuousTransformer` — the core transformer with cross-attention, used by DiffusionTransformer + +From `utils.py`: +- `load_ckpt_state_dict()` — handles .safetensors and .ckpt, optional prefix stripping +- `remove_weight_norm_from_model()` — used in some inference paths + +**Step 2: Handle flash-attn gracefully in transformer.py / blocks.py** + +Replace hard `import flash_attn` with: +```python +try: + from flash_attn import flash_attn_func + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False +``` + +In the attention forward pass, use: +```python +if HAS_FLASH_ATTN: + out = flash_attn_func(q, k, v, ...) +else: + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, ...) +``` + +**Step 3: Verify imports resolve** + +Run: `python -c "from prismaudio_core.factory import create_model_from_config; print('OK')"` from the project root (with ComfyUI's python). + +Expected: `OK` (or import errors to fix iteratively) + +**Step 4: Commit** + +```bash +git add prismaudio_core/models/ +git commit -m "feat: extract prismaudio_core model modules (DiT, conditioners, VAE, diffusion)" +``` + +--- + +### Task 4: Extract prismaudio_core — Inference/Sampling (with callback fix) + +**Files:** +- Create: `prismaudio_core/inference/__init__.py` +- Create: `prismaudio_core/inference/sampling.py` (MODIFIED from `PrismAudio/inference/sampling.py`) +- Create: `prismaudio_core/inference/utils.py` (from `PrismAudio/inference/utils.py`) + +**Step 1: Extract sampling.py WITH callback support added** + +The original `sample_discrete_euler` uses `tqdm` and has no callback parameter. +We MUST copy and modify it to add callback support for ComfyUI progress bars. + +```python +import torch +from tqdm import trange + +def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args): + """Discrete Euler sampler for rectified flow, with optional callback. + + Modified from PrismAudio to add callback parameter for ComfyUI progress reporting. + Original uses tqdm internally. + + Args: + model: The diffusion model (DiTWrapper) + x: Initial noise tensor [B, C, T] + steps: Number of sampling steps + sigma_max: Maximum sigma (default 1.0 for rectified flow) + callback: Optional callable({"i": step, "x": current_x}) for progress + **extra_args: Passed to model() — includes cross_attn_cond, add_cond, + sync_cond, cfg_scale, batch_cfg, etc. + """ + t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype) + + for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])): + dt = t_next - t_curr + t_curr_tensor = t_curr * torch.ones(x.shape[0], dtype=x.dtype, device=x.device) + x = x + dt * model(x, t_curr_tensor, **extra_args) + if callback is not None: + callback({"i": i, "x": x}) + + return x + + + +# Note: sample_rf() and sample() (v-diffusion) are NOT included. +# PrismAudio uses rectified_flow objective which only needs sample_discrete_euler. +# Including unused samplers with potentially wrong math is a maintenance hazard. +``` + +**Step 2: Extract inference/utils.py** + +Keep: +- `set_audio_channels(audio, target_channels)` +- `prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device)` + +**Step 3: Verify sampling import** + +Run: `python -c "from prismaudio_core.inference.sampling import sample_discrete_euler; print('OK')"` + +Expected: `OK` + +**Step 4: Commit** + +```bash +git add prismaudio_core/inference/ +git commit -m "feat: extract prismaudio_core inference with callback-enabled sampling" +``` + +--- + +### Task 5: PrismAudioModelLoader Node + +**Files:** +- Create: `nodes/model_loader.py` + +**Step 1: Write the node** + +Key design decisions: +- No hf_token widget (security risk — saved to workflow JSON). Uses env var / cached login only. +- Creates model with default config. Duration-dependent seq lengths handled at sample time. +- The model config's `sample_size: 397312` corresponds to ~9s default. For other durations, + the Sampler node will update seq lengths on the DiT before each generation. + +```python +import os +import json +import torch +import folder_paths +import comfy.model_management as mm +import comfy.utils + +from .utils import ( + PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder, + get_device, get_offload_device, determine_precision, determine_offload_strategy, + soft_empty_cache, resolve_hf_token, +) + +# HuggingFace repo for auto-download +HF_REPO_ID = "FunAudioLLM/PrismAudio" +REQUIRED_FILES = { + "diffusion": "prismaudio.ckpt", + "vae": "vae.ckpt", + "synchformer": "synchformer_state_dict.pth", +} + + +def _download_if_missing(filename, model_dir, hf_token=None): + """Download a model file from HuggingFace if not present locally.""" + filepath = os.path.join(model_dir, filename) + if os.path.exists(filepath): + return filepath + + from huggingface_hub import hf_hub_download + print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...") + try: + downloaded = hf_hub_download( + repo_id=HF_REPO_ID, + filename=filename, + local_dir=model_dir, + token=hf_token or None, + ) + return downloaded + except Exception as e: + if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower(): + raise RuntimeError( + f"[PrismAudio] Model '{filename}' requires license acceptance. " + f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, " + f"then set HF_TOKEN env var or run: huggingface-cli login" + ) from e + raise + + +class PrismAudioModelLoader: + @classmethod + def INPUT_TYPES(cls): + register_model_folder() + return { + "required": { + "precision": (["auto", "fp32", "fp16", "bf16"],), + "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],), + }, + } + + RETURN_TYPES = ("PRISMAUDIO_MODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "load_model" + CATEGORY = PRISMAUDIO_CATEGORY + + def load_model(self, precision, offload_strategy): + device = get_device() + dtype = determine_precision(precision, device) + strategy = determine_offload_strategy(offload_strategy) + token = resolve_hf_token() + model_dir = get_prismaudio_model_dir() + + # Auto-download missing files + for key, filename in REQUIRED_FILES.items(): + _download_if_missing(filename, model_dir, hf_token=token) + + # Load config + config_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "prismaudio_core", "configs", "prismaudio.json" + ) + with open(config_path) as f: + model_config = json.load(f) + + # Create model from config + from prismaudio_core.factory import create_model_from_config + model = create_model_from_config(model_config) + + # Load diffusion weights + diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"]) + diffusion_state = comfy.utils.load_torch_file(diffusion_path) + # Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...} + if "state_dict" in diffusion_state: + diffusion_state = diffusion_state["state_dict"] + model.load_state_dict(diffusion_state, strict=False) + + # Load VAE weights separately + # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat + vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"]) + vae_full_state = comfy.utils.load_torch_file(vae_path) + # Strip "autoencoder." prefix from keys + vae_state = {} + prefix = "autoencoder." + for k, v in vae_full_state.items(): + if k.startswith(prefix): + vae_state[k[len(prefix):]] = v + else: + vae_state[k] = v + model.pretransform.load_state_dict(vae_state) + + # Apply precision: DiT + conditioners in user-selected dtype, + # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16 + model.model.to(dtype) # DiTWrapper + model.conditioner.to(dtype) # MultiConditioner + # model.pretransform stays in fp32 + + if strategy == "keep_in_vram": + model = model.to(device) + else: + model = model.to(get_offload_device()) + + model.eval() + + return ({ + "model": model, + "dtype": dtype, + "strategy": strategy, + "config": model_config, + "model_dir": model_dir, + },) +``` + +**Step 2: Test that ComfyUI discovers the node** + +Run ComfyUI and check that "PrismAudio Model Loader" appears in the node list. + +**Step 3: Commit** + +```bash +git add nodes/model_loader.py +git commit -m "feat: PrismAudioModelLoader node with auto-download and adaptive VRAM" +``` + +--- + +### Task 6: PrismAudioFeatureLoader Node + +**Files:** +- Create: `nodes/feature_loader.py` + +**Step 1: Write the node** + +```python +import os +import numpy as np +import torch +from .utils import PRISMAUDIO_CATEGORY + +# Keys consumed by the conditioners (video_features, text_features, sync_features) +# global_video_features and global_text_features are NOT consumed by any conditioner +# in the prismaudio.json config — they are unused. +REQUIRED_KEYS = [ + "video_features", + "text_features", + "sync_features", +] + + +class PrismAudioFeatureLoader: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}), + }, + } + + RETURN_TYPES = ("PRISMAUDIO_FEATURES",) + RETURN_NAMES = ("features",) + FUNCTION = "load_features" + CATEGORY = PRISMAUDIO_CATEGORY + + def load_features(self, npz_path): + if not os.path.exists(npz_path): + raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}") + + data = np.load(npz_path, allow_pickle=True) + + features = {} + for key in REQUIRED_KEYS: + if key in data: + features[key] = torch.from_numpy(data[key]).float() + else: + print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros") + # Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None + # Sync_MLP requires length divisible by 8 (segments of 8 frames) + if key == "sync_features": + features[key] = torch.zeros(8, 768) + else: + features[key] = torch.zeros(1, 1024) + + # Load duration if present + if "duration" in data: + features["duration"] = float(data["duration"]) + + return (features,) +``` + +**Step 2: Commit** + +```bash +git add nodes/feature_loader.py +git commit -m "feat: PrismAudioFeatureLoader node for pre-computed .npz files" +``` + +--- + +### Task 7: PrismAudioFeatureExtractor Node (Subprocess Bridge) + +**Files:** +- Create: `nodes/feature_extractor.py` +- Create: `scripts/extract_features.py` +- Create: `scripts/environment.yml` + +**Step 1: Create the conda environment.yml** + +```yaml +name: prismaudio-extract +channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - pip + - ffmpeg<7 + - pip: + - torch>=2.6.0 + - torchaudio>=2.6.0 + - torchvision>=0.21.0 + - tensorflow-cpu==2.15.0 + - jax + - jaxlib + - transformers>=4.52.3 + - decord + - einops>=0.7.0 + - numpy + - mediapy + - git+https://github.com/google-deepmind/videoprism.git +``` + +**Step 2: Create scripts/extract_features.py** + +This is a standalone script that: +1. Takes `--video`, `--cot_text`, `--output` arguments +2. Loads VideoPrism, T5-Gemma, Synchformer +3. Extracts features from the video +4. Saves as `.npz` + +```python +#!/usr/bin/env python3 +""" +Standalone PrismAudio feature extraction script. +Run in a separate conda env with JAX/TF installed. + +Usage: + python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz + +Setup: + conda env create -f environment.yml + conda activate prismaudio-extract +""" + +import argparse +import os +import sys +import numpy as np +import torch + + +def main(): + parser = argparse.ArgumentParser(description="PrismAudio feature extraction") + parser.add_argument("--video", required=True, help="Path to input video") + parser.add_argument("--cot_text", required=True, help="Chain-of-thought description") + parser.add_argument("--output", required=True, help="Output .npz path") + parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint") + parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON") + parser.add_argument("--clip_fps", type=float, default=4.0) + parser.add_argument("--clip_size", type=int, default=288) + parser.add_argument("--sync_fps", type=float, default=25.0) + parser.add_argument("--sync_size", type=int, default=224) + args = parser.parse_args() + + if not os.path.exists(args.video): + print(f"Error: Video not found: {args.video}") + sys.exit(1) + + # Import feature extraction utils (requires JAX/TF) + from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils + import torchvision.transforms as T + from decord import VideoReader, cpu + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Initialize feature extractor + feat_utils = FeaturesUtils( + vae_config_path=args.vae_config, + synchformer_ckpt=args.synchformer_ckpt, + device=device, + ) + + # Load and preprocess video + vr = VideoReader(args.video, ctx=cpu(0)) + fps = vr.get_avg_fps() + total_frames = len(vr) + duration = total_frames / fps + + # Extract CLIP frames (4fps, 288x288) + clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] + clip_indices = [min(i, total_frames - 1) for i in clip_indices] + clip_frames = vr.get_batch(clip_indices).asnumpy() + + clip_transform = T.Compose([ + T.ToPILImage(), + T.Resize(args.clip_size), + T.CenterCrop(args.clip_size), + T.ToTensor(), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device) + + # Extract Sync frames (25fps, 224x224) + sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] + sync_indices = [min(i, total_frames - 1) for i in sync_indices] + sync_frames = vr.get_batch(sync_indices).asnumpy() + + sync_transform = T.Compose([ + T.ToPILImage(), + T.Resize(args.sync_size), + T.CenterCrop(args.sync_size), + T.ToTensor(), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device) + + # Extract features + print("[PrismAudio] Encoding text with T5-Gemma...") + text_features = feat_utils.encode_t5_text([args.cot_text]) + + print("[PrismAudio] Encoding video with VideoPrism...") + global_video_features, video_features, global_text_features = \ + feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text]) + + print("[PrismAudio] Encoding video with Synchformer...") + sync_features = feat_utils.encode_video_with_sync(sync_input) + + # Save as .npz + np.savez( + args.output, + video_features=video_features.cpu().numpy(), + global_video_features=global_video_features.cpu().numpy(), + text_features=text_features.cpu().numpy(), + global_text_features=global_text_features.cpu().numpy(), + sync_features=sync_features.cpu().numpy(), + caption_cot=args.cot_text, + duration=duration, + ) + print(f"[PrismAudio] Features saved to {args.output}") + + +if __name__ == "__main__": + main() +``` + +**Step 3: Create the feature extractor node** + +```python +import os +import hashlib +import subprocess +import tempfile +import torch + +from .utils import PRISMAUDIO_CATEGORY +from .feature_loader import PrismAudioFeatureLoader + + +def _hash_inputs(video_tensor, cot_text): + """Create a hash of the inputs for caching.""" + h = hashlib.sha256() + h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed + h.update(cot_text.encode()) + return h.hexdigest()[:16] + + +def _save_video_tensor_to_mp4(video_tensor, output_path, fps=30): + """Save ComfyUI IMAGE tensor [T,H,W,C] to MP4.""" + import torchvision.io as tvio + # ComfyUI IMAGE is [T,H,W,C] float32 [0,1] + frames = (video_tensor * 255).to(torch.uint8) + # torchvision write_video expects [T,H,W,C] uint8 + tvio.write_video(output_path, frames, fps=fps) + + +class PrismAudioFeatureExtractor: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "video": ("IMAGE",), + "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}), + }, + "optional": { + "python_env": ("STRING", {"default": "python", "tooltip": "Path to python binary with JAX/TF (e.g., /path/to/conda/envs/prismaudio-extract/bin/python)"}), + "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}), + "synchformer_ckpt": ("STRING", {"default": "", "tooltip": "Path to synchformer checkpoint (auto-resolved if empty)"}), + }, + } + + RETURN_TYPES = ("PRISMAUDIO_FEATURES",) + RETURN_NAMES = ("features",) + FUNCTION = "extract_features" + CATEGORY = PRISMAUDIO_CATEGORY + + def extract_features(self, video, caption_cot, python_env="python", cache_dir="", synchformer_ckpt=""): + # Determine cache directory + if not cache_dir: + cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features") + os.makedirs(cache_dir, exist_ok=True) + + # Check cache + cache_hash = _hash_inputs(video, caption_cot) + cached_path = os.path.join(cache_dir, f"{cache_hash}.npz") + if os.path.exists(cached_path): + print(f"[PrismAudio] Using cached features: {cached_path}") + loader = PrismAudioFeatureLoader() + return loader.load_features(cached_path) + + # Save video to temp file + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp: + tmp_video = tmp.name + _save_video_tensor_to_mp4(video, tmp_video) + + # Build subprocess command + script_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "scripts", "extract_features.py" + ) + + cmd = [ + python_env, + script_path, + "--video", tmp_video, + "--cot_text", caption_cot, + "--output", cached_path, + ] + if synchformer_ckpt: + cmd.extend(["--synchformer_ckpt", synchformer_ckpt]) + + print(f"[PrismAudio] Extracting features via subprocess...") + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=600, # 10 minute timeout + ) + if result.returncode != 0: + raise RuntimeError( + f"[PrismAudio] Feature extraction failed:\n{result.stderr}" + ) + print(result.stdout) + finally: + if os.path.exists(tmp_video): + os.unlink(tmp_video) + + # Load the extracted features + loader = PrismAudioFeatureLoader() + return loader.load_features(cached_path) +``` + +**Step 4: Commit** + +```bash +git add nodes/feature_extractor.py scripts/extract_features.py scripts/environment.yml +git commit -m "feat: PrismAudioFeatureExtractor node with subprocess bridge and conda env" +``` + +--- + +### Task 8: PrismAudioSampler Node + +**Files:** +- Create: `nodes/sampler.py` + +**Step 1: Write the sampler node** + +This is the core node. Key fixes from review: +- Metadata is a TUPLE of dicts, not a flat dict +- video_exist is torch.tensor, not Python bool +- Empty features are zero tensors, not None +- Peak normalization before clamp +- Sequence lengths set on DiT config before sampling (matching predict.py approach) +- No callback kwarg forwarded to model — callback is handled by our modified sample_discrete_euler + +```python +import torch +import comfy.model_management as mm +import comfy.utils + +from .utils import ( + PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS, + get_device, get_offload_device, soft_empty_cache, +) + + +class PrismAudioSampler: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("PRISMAUDIO_MODEL",), + "features": ("PRISMAUDIO_FEATURES",), + "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds"}), + "steps": ("INT", {"default": 24, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}), + "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), + }, + } + + RETURN_TYPES = ("AUDIO",) + RETURN_NAMES = ("audio",) + FUNCTION = "generate" + CATEGORY = PRISMAUDIO_CATEGORY + + def generate(self, model, features, duration, steps, cfg_scale, seed): + device = get_device() + dtype = model["dtype"] + strategy = model["strategy"] + diffusion = model["model"] + + # Compute latent dimensions + latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO) + + # Note: no seq length config needed — the model adapts to input tensor shapes + # dynamically via its transformer architecture. + + # Determine if video features are present (not all zeros) + has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0 + + # Build metadata as a TUPLE of dicts (one per batch sample) + # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this + sample_meta = { + "video_features": features["video_features"].to(device, dtype=dtype), + "text_features": features["text_features"].to(device, dtype=dtype), + "sync_features": features["sync_features"].to(device, dtype=dtype), + "video_exist": torch.tensor(has_video), + } + metadata = (sample_meta,) + + # Move model to device if offloaded + if strategy == "offload_to_cpu": + diffusion.model.to(device) + diffusion.conditioner.to(device) + soft_empty_cache() + + with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype): + # Run conditioning + conditioning = diffusion.conditioner(metadata, device) + + # Handle missing video: substitute learned empty embeddings + if not has_video: + _substitute_empty_features(diffusion, conditioning, device, dtype) + + # Assemble conditioning inputs for the DiT + cond_inputs = diffusion.get_conditioning_inputs(conditioning) + + # Generate noise from seed (MPS doesn't support torch.Generator) + gen_device = "cpu" if device.type == "mps" else device + generator = torch.Generator(device=gen_device).manual_seed(seed) + noise = torch.randn( + [1, IO_CHANNELS, latent_length], + generator=generator, + device=gen_device, + ).to(device=device, dtype=dtype) + + # Sample with progress bar + pbar = comfy.utils.ProgressBar(steps) + + from prismaudio_core.inference.sampling import sample_discrete_euler + + def on_step(info): + pbar.update(1) + + fakes = sample_discrete_euler( + diffusion.model, + noise, + steps, + callback=on_step, + **cond_inputs, + cfg_scale=cfg_scale, + batch_cfg=True, + ) + + # Offload diffusion model and conditioner before VAE decode + if strategy == "offload_to_cpu": + diffusion.model.to(get_offload_device()) + diffusion.conditioner.to(get_offload_device()) + soft_empty_cache() + diffusion.pretransform.to(device) + + # VAE decode in fp32 (snake activations overflow in fp16) + with torch.amp.autocast(device_type=device.type, enabled=False): + audio = diffusion.pretransform.decode(fakes.float()) + + # Offload VAE + if strategy == "offload_to_cpu": + diffusion.pretransform.to(get_offload_device()) + soft_empty_cache() + + # Peak normalize then clamp (matching reference: div by max abs before clamp) + audio = audio.float() + peak = audio.abs().max().clamp(min=1e-8) + audio = (audio / peak).clamp(-1, 1) + + # Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int} + return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},) + + +def _substitute_empty_features(diffusion, conditioning, device, dtype): + """Replace sync conditioning with learned empty embedding when video is absent. + + Only substitutes sync_features — NOT video_features. The reference code + (predict.py/app.py) checks for 'metaclip_features' which doesn't exist in the + prismaudio.json config, so video substitution never runs. Cond_MLP with zero + input + bias-free linear layers naturally produces near-zero output. + + The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim]. + """ + dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model + + # Only substitute sync_features (matching reference behavior for prismaudio config) + if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning: + empty = dit.empty_sync_feat.to(device, dtype=dtype) + cond_tensor = conditioning['sync_features'][0] + batch_size = cond_tensor.shape[0] + empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) + conditioning['sync_features'][0] = empty_expanded + conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device) +``` + +**Step 2: Verify the node registers** + +Start ComfyUI, check "PrismAudio Sampler" appears in add-node menu. + +**Step 3: Commit** + +```bash +git add nodes/sampler.py +git commit -m "feat: PrismAudioSampler node with correct metadata format and peak normalization" +``` + +--- + +### Task 9: PrismAudioTextOnly Node + +**Files:** +- Create: `nodes/text_only.py` + +**Step 1: Write the text-only node** + +Key fixes from review: +- Uses `AutoModelForSeq2SeqLM.get_encoder()`, not `AutoModel.encoder` +- No truncation (matching reference) +- Metadata is tuple of dicts with torch.tensor(False) for video_exist +- Zero tensors for video/sync features, not None +- Peak normalization + +```python +import torch +import comfy.model_management as mm +import comfy.utils + +from .utils import ( + PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS, + get_device, get_offload_device, soft_empty_cache, resolve_hf_token, +) +from .sampler import _substitute_empty_features + + +class PrismAudioTextOnly: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("PRISMAUDIO_MODEL",), + "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Text description for audio generation"}), + "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}), + "steps": ("INT", {"default": 24, "min": 1, "max": 100}), + "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), + }, + } + + RETURN_TYPES = ("AUDIO",) + RETURN_NAMES = ("audio",) + FUNCTION = "generate" + CATEGORY = PRISMAUDIO_CATEGORY + + def generate(self, model, text_prompt, duration, steps, cfg_scale, seed): + device = get_device() + dtype = model["dtype"] + strategy = model["strategy"] + diffusion = model["model"] + + latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO) + + # Encode text with T5-Gemma + text_features = _encode_text_t5(text_prompt, device, dtype) + + # Build metadata: tuple of one dict per sample + # Use zero tensors for video/sync (not None — Cond_MLP crashes on None via pad_sequence) + # Sync_MLP requires length divisible by 8 (segments of 8 frames) — minimum [8, 768] + # These will be substituted with learned empty embeddings after conditioning + sample_meta = { + "video_features": torch.zeros(1, 1024, device=device, dtype=dtype), + "text_features": text_features.to(device, dtype=dtype), + "sync_features": torch.zeros(8, 768, device=device, dtype=dtype), + "video_exist": torch.tensor(False), + } + metadata = (sample_meta,) + + if strategy == "offload_to_cpu": + diffusion.model.to(device) + diffusion.conditioner.to(device) + soft_empty_cache() + + with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype): + conditioning = diffusion.conditioner(metadata, device) + + # Substitute empty features for video/sync + _substitute_empty_features(diffusion, conditioning, device, dtype) + + cond_inputs = diffusion.get_conditioning_inputs(conditioning) + + # Generate noise from seed (MPS doesn't support torch.Generator) + gen_device = "cpu" if device.type == "mps" else device + generator = torch.Generator(device=gen_device).manual_seed(seed) + noise = torch.randn( + [1, IO_CHANNELS, latent_length], + generator=generator, + device=gen_device, + ).to(device=device, dtype=dtype) + + pbar = comfy.utils.ProgressBar(steps) + + from prismaudio_core.inference.sampling import sample_discrete_euler + + def on_step(info): + pbar.update(1) + + fakes = sample_discrete_euler( + diffusion.model, + noise, + steps, + callback=on_step, + **cond_inputs, + cfg_scale=cfg_scale, + batch_cfg=True, + ) + + if strategy == "offload_to_cpu": + diffusion.model.to(get_offload_device()) + diffusion.conditioner.to(get_offload_device()) + soft_empty_cache() + diffusion.pretransform.to(device) + + # VAE decode in fp32 (snake activations overflow in fp16) + with torch.amp.autocast(device_type=device.type, enabled=False): + audio = diffusion.pretransform.decode(fakes.float()) + + if strategy == "offload_to_cpu": + diffusion.pretransform.to(get_offload_device()) + soft_empty_cache() + + # Peak normalize then clamp + audio = audio.float() + peak = audio.abs().max().clamp(min=1e-8) + audio = (audio / peak).clamp(-1, 1) + + return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},) + + +# T5-Gemma encoder singleton +_t5_model = None +_t5_tokenizer = None + + +def _encode_text_t5(text, device, dtype): + """Encode text using T5-Gemma. + + Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference + FeaturesUtils.encode_t5_text() implementation. + No truncation applied (matching reference behavior). + """ + global _t5_model, _t5_tokenizer + + if _t5_model is None: + from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + model_id = "google/t5gemma-l-l-ul2-it" + token = resolve_hf_token() + print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}") + _t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token) + _t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder() + _t5_model.eval() + + _t5_model.to(device, dtype=dtype) + + tokens = _t5_tokenizer( + text, + return_tensors="pt", + padding=True, + ).to(device) + + with torch.no_grad(): + outputs = _t5_model(**tokens) + + # Move T5 off GPU after encoding to save VRAM + _t5_model.to("cpu") + soft_empty_cache() + + return outputs.last_hidden_state.squeeze(0) # [seq_len, dim] +``` + +**Step 2: Commit** + +```bash +git add nodes/text_only.py +git commit -m "feat: PrismAudioTextOnly node with correct T5-Gemma encoding" +``` + +--- + +### Task 10: Integration Testing & Polish + +**Files:** +- Modify: `nodes/__init__.py` (verify all imports work) +- Modify: `__init__.py` (verify top-level registration) + +**Step 1: Verify all node imports resolve** + +Run from ComfyUI's Python: +```bash +cd /path/to/ComfyUI +python -c " +import sys +sys.path.insert(0, 'custom_nodes/ComfyUI-PrismAudio') +from nodes import NODE_CLASS_MAPPINGS +print('Registered nodes:', list(NODE_CLASS_MAPPINGS.keys())) +" +``` + +Expected output: +``` +Registered nodes: ['PrismAudioModelLoader', 'PrismAudioFeatureLoader', 'PrismAudioFeatureExtractor', 'PrismAudioSampler', 'PrismAudioTextOnly'] +``` + +**Step 2: Fix any import errors iteratively** + +Common issues: +- `prismaudio_core` internal imports may reference wrong module paths +- Missing model submodules in `prismaudio_core/models/` +- flash-attn fallback not properly guarded + +**Step 3: Test model loading (requires GPU + model files)** + +```bash +python -c " +from prismaudio_core.factory import create_model_from_config +import json +with open('prismaudio_core/configs/prismaudio.json') as f: + config = json.load(f) +model = create_model_from_config(config) +print('Model created, params:', sum(p.numel() for p in model.parameters()) / 1e6, 'M') +" +``` + +Expected: `Model created, params: ~518 M` + +**Step 4: End-to-end test with pre-computed features** + +If you have a `.npz` feature file from the PrismAudio repo's demo data, test the full pipeline in ComfyUI: +1. PrismAudioModelLoader -> PrismAudioFeatureLoader -> PrismAudioSampler -> Preview Audio node + +**Step 5: Verify variable duration handling** + +Test with multiple durations (5s, 10s, 20s) to ensure the model adapts to different +input shapes and produces audio of the expected length. + +**Step 6: Commit** + +```bash +git add -A +git commit -m "feat: integration fixes and verification" +``` + +--- + +### Task 11: README + +**Files:** +- Create: `README.md` + +**Step 1: Write README covering:** + +- What PrismAudio is (brief, link to paper) +- Installation (clone, pip install requirements, optional extraction env setup) +- Node descriptions with input/output tables +- Example workflows (quality path with FeatureExtractor, quick path with FeatureLoader, text-only) +- HuggingFace authentication (2 methods: `HF_TOKEN` env var, `huggingface-cli login`) + - Note: hf_token is NOT a node widget for security reasons + - Which models may be gated (T5-Gemma, potentially Stable Audio VAE) +- Model file sizes: diffusion ~2.7GB, VAE ~2.5GB, synchformer ~950MB +- Extraction env setup via conda environment.yml +- Troubleshooting (VRAM, flash-attn optional, gated models) +- Credits and license + +**Step 2: Commit** + +```bash +git add README.md +git commit -m "docs: README with installation and usage instructions" +``` + +--- + +## Dependency Graph + +``` +Task 1 (scaffolding) + ├── Task 2 (core config + factory) ──┐ + │ └── Task 3 (core models) ──────┤ + │ └── Task 4 (core sampling)┤ + │ ├── Task 5 (ModelLoader node) + │ ├── Task 6 (FeatureLoader node) + │ ├── Task 7 (FeatureExtractor node) + │ ├── Task 8 (Sampler node) + │ └── Task 9 (TextOnly node) + └────────────────────────────────────────── Task 10 (Integration) + └── Task 11 (README) +``` + +Tasks 5-9 can be parallelized after Task 4 is complete. Task 3 is the heaviest — it involves extracting and adapting ~10 model files.