# ComfyUI-PrismAudio Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Build ComfyUI custom nodes for PrismAudio video-to-audio and text-to-audio generation with adaptive VRAM management and isolated feature extraction.

**Architecture:** Selective code extraction from PrismAudio `prismaudio` branch into `prismaudio_core/` module. 5 ComfyUI nodes (ModelLoader, FeatureLoader, FeatureExtractor, Sampler, TextOnly). Feature extraction via subprocess bridge to isolated JAX/TF environment. Auto-download from HuggingFace with gated model support.

**Tech Stack:** PyTorch, ComfyUI APIs (folder_paths, comfy.model_management, comfy.utils), HuggingFace Hub, transformers (T5-Gemma), einops, k-diffusion, safetensors

---

## Bug Fixes Applied (from review)

This plan incorporates fixes for all 14 bugs identified during review:

1. **sample_discrete_euler callback**: Copy function into prismaudio_core, add callback param to the sampling loop
2. **Metadata format**: Return `(dict,)` tuple, not flat dict; this matches `MultiConditioner.forward(batch_metadata: List[Dict])`
3. **video_exist**: Use `torch.tensor(True/False)`, not Python bool
4. **None features**: Use zero tensors of correct shape, never None, because `pad_sequence(None)` crashes
5. **update_seq_lengths removed**: Does not exist in source. Model adapts to input shapes dynamically, so no seq length config is needed
6. **Sequence length config**: Not needed; the model handles variable lengths natively via input tensor shapes
7. **T5-Gemma class**: Use `AutoModelForSeq2SeqLM.get_encoder()`, not `AutoModel.encoder`
8. **Peak normalization**: Add `audio / audio.abs().max().clamp(min=1e-8)` before clamp
9. **Empty feature substitution**: Match reference approach: substitute on raw conditioning output with correct shapes
10. **hf_token security**: Remove STRING widget entirely. Rely on env var / `huggingface-cli login` only. Document in README
11. **Synchformer size**: Corrected to ~950MB in docs
12. **T5 truncation**: Match reference: `truncation=False`, no max_length
13. **Remove global_video/text_features from metadata**: Not consumed by any conditioner
14. **Add tqdm to requirements**
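
The arithmetic behind fix 8 can be illustrated without torch. The sketch below applies the same steps (divide by the peak magnitude, floored at 1e-8, then clamp to [-1, 1]) to a plain list of samples. The function name `peak_normalize` is hypothetical; the plan itself operates on tensors.

```python
def peak_normalize(samples, eps=1e-8):
    """Scale samples so the largest magnitude is 1.0, then clamp to [-1, 1].

    Mirrors `audio / audio.abs().max().clamp(min=1e-8)` on a plain list.
    """
    peak = max((abs(s) for s in samples), default=0.0)
    peak = max(peak, eps)  # floor avoids division by zero on silent audio
    return [min(1.0, max(-1.0, s / peak)) for s in samples]


loud = peak_normalize([0.5, -2.0, 1.0])
silent = peak_normalize([0.0, 0.0])
```

The floor matters for the silent case: without it, an all-zero clip would divide by zero.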

### Bug Fixes Applied (from second review)

15. **Sync_MLP zero-tensor crash**: Sync zero-tensor fallback must be `[8, 768]` not `[1, 768]`. Sync_MLP does `length // 8`, which gives 0 for length=1, causing `F.interpolate` on an empty tensor
16. **sample_discrete_euler undefined `i`**: Loop needs `enumerate()`: `for i, (t_curr, t_next) in enumerate(zip(...))`
17. **_update_seq_lengths removed entirely**: Was a no-op (attributes don't exist on DiT). Model handles variable lengths natively, so the function is deleted
18. **cot_description removed from Sampler**: Was dead code; features already contain pre-computed text_features
19. **Conditioner VRAM leak**: Add `diffusion.conditioner.to(get_offload_device())` after generation in offload path
20. **VAE size corrected**: ~2.52GB, not ~300MB
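
The arithmetic behind fix 15 is easy to check in isolation. The sketch assumes, as the fix describes, that Sync_MLP groups frames into segments of 8 via integer division. `sync_segments` and `fallback_sync_shape` are hypothetical helper names, not functions from PrismAudio.

```python
SEGMENT = 8      # frames per sync segment, per the fix description
SYNC_DIM = 768   # feature width of the zero-tensor fallback


def sync_segments(length, segment=SEGMENT):
    """Number of whole segments a sequence of `length` frames yields."""
    return length // segment


def fallback_sync_shape(segment=SEGMENT, dim=SYNC_DIM):
    """Smallest zero-tensor shape that still yields one full segment."""
    return (segment, dim)


bad = sync_segments(1)                          # a [1, 768] fallback
good = sync_segments(fallback_sync_shape()[0])  # an [8, 768] fallback
```

With a length of 1 the segment count is zero, which is the empty-tensor case the fix avoids.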

### Bug Fixes Applied (from third review)

21. **Remove video_features substitution**: `_substitute_empty_features` should only substitute sync_features. Reference code checks for `metaclip_features` (wrong key for prismaudio config), so video substitution never runs. Cond_MLP with zero input + bias-free linears naturally produces near-zero output
22. **Remove dead `sample()` and `sample_rf`**: Wrong noise schedule (linear vs cosine), never called for rectified_flow. Only keep `sample_discrete_euler`
23. **VAE decode in fp32**: Keep pretransform in fp32 even when rest of model is fp16/bf16, because snake activations overflow in fp16
24. **Lazy imports in nodes/__init__.py**: Use try/except to allow incremental development
25. **MPS Generator guard**: `torch.Generator(device="cpu")` on Apple Silicon, move noise to device after
26. **Use comfy.utils.load_torch_file for VAE**: Consistent with diffusion loading, handles PyTorch 2.6+ weights_only default
27. **Task 10 stale reference**: Remove mention of `_update_seq_lengths`

### Bug Fixes Applied (from fourth review)

28. **TextOnly missing MPS guard**: Fix-on-fix regression: the MPS Generator guard was applied to Sampler but not TextOnly
29. **TextOnly noise dtype**: Was passing dtype to torch.randn directly (fp16 noise), now generates fp32 then converts (matching Sampler)
30. **Sync substitution seq length**: Low-severity divergence from reference, accepted (DiT handles variable-length sync_cond)

---

### Task 1: Project Scaffolding

**Files:**
- Create: `__init__.py`
- Create: `nodes/__init__.py`
- Create: `nodes/utils.py`
- Create: `requirements.txt`

**Step 1: Create requirements.txt**

```
einops>=0.7.0
safetensors
huggingface_hub
transformers>=4.52.3
k-diffusion>=0.1.1
alias-free-torch
descript-audio-codec
tqdm
```

**Step 2: Create nodes/utils.py with shared helpers**

```python
import os
import torch
import folder_paths
import comfy.model_management as mm

PRISMAUDIO_CATEGORY = "PrismAudio"
SAMPLE_RATE = 44100
DOWNSAMPLING_RATIO = 2048
IO_CHANNELS = 64


def get_prismaudio_model_dir():
    """Get or create the prismaudio model directory."""
    model_dir = os.path.join(folder_paths.models_dir, "prismaudio")
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def register_model_folder():
    """Register prismaudio model folder with ComfyUI."""
    model_dir = get_prismaudio_model_dir()
    folder_paths.add_model_folder_path("prismaudio", model_dir)


def get_device():
    return mm.get_torch_device()


def get_offload_device():
    return mm.unet_offload_device()


def get_free_memory(device=None):
    if device is None:
        device = get_device()
    return mm.get_free_memory(device)


def soft_empty_cache():
    mm.soft_empty_cache()


def determine_precision(preference, device):
    """Determine the best precision for the given device."""
    if preference != "auto":
        return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference]
    if device.type == "cpu":
        return torch.float32
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16


def determine_offload_strategy(preference):
    """Determine offload strategy based on available VRAM."""
    if preference != "auto":
        return preference
    free_mem = get_free_memory()
    gb = free_mem / (1024 ** 3)
    if gb >= 24:
        return "keep_in_vram"
    else:
        return "offload_to_cpu"


def try_import_flash_attn():
    """Try to import flash attention, return None if unavailable."""
    try:
        import flash_attn
        return flash_attn
    except ImportError:
        return None


def resolve_hf_token():
    """Resolve HF token from env var or cached login. No widget (security risk)."""
    env_token = os.environ.get("HF_TOKEN")
    if env_token:
        return env_token
    # huggingface_hub will use cached token automatically if None is passed
    return None
```
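
The `auto` branch of `determine_offload_strategy` depends on live VRAM, which makes it awkward to exercise without a GPU. One option, sketched here as an assumption rather than as part of the plan, is to factor the threshold check into a pure function that takes the free-byte count as an argument. `pick_offload_strategy` is a hypothetical name.

```python
GIB = 1024 ** 3


def pick_offload_strategy(preference, free_bytes, threshold_gib=24):
    """Return the offload strategy for a given preference and free VRAM.

    An explicit preference wins; "auto" keeps the model resident only when
    at least `threshold_gib` GiB are free (the same 24 GiB cutoff as above).
    """
    if preference != "auto":
        return preference
    if free_bytes / GIB >= threshold_gib:
        return "keep_in_vram"
    return "offload_to_cpu"


small_gpu = pick_offload_strategy("auto", 12 * GIB)
large_gpu = pick_offload_strategy("auto", 24 * GIB)
forced = pick_offload_strategy("keep_in_vram", 4 * GIB)
```

A wrapper that passes `get_free_memory()` as `free_bytes` would keep the node-facing behavior unchanged.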

**Step 3: Create nodes/__init__.py**

```python
import importlib

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

# Lazy imports: allows incremental development (nodes can be added one at a time)
_NODES = {
    "PrismAudioModelLoader": (".model_loader", "PrismAudioModelLoader", "PrismAudio Model Loader"),
    "PrismAudioFeatureLoader": (".feature_loader", "PrismAudioFeatureLoader", "PrismAudio Feature Loader"),
    "PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
    "PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
    "PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
}

for key, (module_path, class_name, display_name) in _NODES.items():
    try:
        # Relative import resolved against this package; a missing module is skipped
        mod = importlib.import_module(module_path, package=__name__)
        NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
        NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
    except (ImportError, AttributeError) as e:
        print(f"[PrismAudio] Skipping {key}: {e}")
```
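
The registration loop can be tried outside ComfyUI by pointing it at standard-library modules. This is only a sketch of the pattern: the `register` helper and the `no_such_module_xyz` entry are made up for the demonstration, and absolute module names stand in for the relative ones used above.

```python
import importlib


def register(specs):
    """Import each (module, attribute) pair, skipping any that fail to load."""
    loaded, skipped = {}, []
    for key, (module_name, attr_name) in specs.items():
        try:
            module = importlib.import_module(module_name)
            loaded[key] = getattr(module, attr_name)
        except (ImportError, AttributeError) as exc:
            skipped.append(key)
            print(f"Skipping {key}: {exc}")
    return loaded, skipped


loaded, skipped = register({
    "Encoder": ("json", "JSONEncoder"),             # exists
    "Missing": ("no_such_module_xyz", "Anything"),  # module not installed
    "BadAttr": ("json", "NoSuchClass"),             # module exists, class does not
})
```

Both failure modes are skipped rather than raised, which is what lets nodes be added one task at a time.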

**Step 4: Create top-level __init__.py**

```python
"""
ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026).
"""
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
```

**Step 5: Commit**

```bash
git add __init__.py nodes/__init__.py nodes/utils.py requirements.txt
git commit -m "feat: project scaffolding with shared utils and node registration"
```

---

### Task 2: Extract prismaudio_core - Model Config + Factory

**Files:**
- Create: `prismaudio_core/__init__.py`
- Create: `prismaudio_core/configs/prismaudio.json` (copy from PrismAudio repo)
- Create: `prismaudio_core/factory.py` (adapted from `PrismAudio/models/factory.py`)

**Step 1: Create prismaudio_core/__init__.py**

```python
"""
PrismAudio core inference modules.
Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch).
Only inference-critical code: no training, no JAX/TF dependencies.
"""
```

**Step 2: Copy prismaudio.json config**

Fetch from `https://raw.githubusercontent.com/FunAudioLLM/ThinkSound/prismaudio/PrismAudio/configs/model_configs/prismaudio.json` and save to `prismaudio_core/configs/prismaudio.json`. This is a JSON config file with no code; copy it verbatim.

**Step 3: Create factory.py**

Extract from `PrismAudio/models/factory.py`. Keep only these functions (remove training-related code):
- `create_model_from_config(model_config)`: entry point
- `create_diffusion_cond_from_config(config)`: creates the full model
- `create_pretransform_from_config(pretransform_config, sample_rate)`: VAE
- `create_autoencoder_from_config(config)`: AudioAutoencoder
- `create_bottleneck_from_config(config)`: VAEBottleneck
- `create_multi_conditioner_from_conditioning_config(config)`: conditioners

All imports should reference `prismaudio_core.models.*` instead of `PrismAudio.models.*`.
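
If the rewrite is done mechanically, a small script can handle it. The following is a sketch under the assumption that the upstream files use dotted `PrismAudio.models` paths in `import`/`from` statements. `rewrite_imports` is a hypothetical helper, and relative imports, if the upstream uses them, would need no change.

```python
import re

# Matches the dotted package path only when it appears as a whole name,
# so identifiers such as `MyPrismAudio.models` are left untouched.
_OLD_PACKAGE = re.compile(r"(?<![\w.])PrismAudio\.models\b")


def rewrite_imports(source: str) -> str:
    """Point `PrismAudio.models` imports at `prismaudio_core.models`."""
    rewritten = []
    for line in source.splitlines():
        stripped = line.lstrip()
        # Only touch import statements; strings and comments stay as they are.
        if stripped.startswith(("import ", "from ")):
            line = _OLD_PACKAGE.sub("prismaudio_core.models", line)
        rewritten.append(line)
    return "\n".join(rewritten)


example = (
    "from PrismAudio.models.dit import DiffusionTransformer\n"
    "import PrismAudio.models.utils as model_utils\n"
    "x = 'PrismAudio.models stays inside strings'\n"
)
converted = rewrite_imports(example)
```

Reviewing the resulting diff by hand is still advisable, since imports inside functions or conditional blocks may reference training-only modules that should be removed instead.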

**Step 4: Commit**

```bash
git add prismaudio_core/
git commit -m "feat: extract prismaudio_core config and model factory"
```

---

### Task 3: Extract prismaudio_core - Model Modules

**Files:**
- Create: `prismaudio_core/models/__init__.py`
- Create: `prismaudio_core/models/dit.py` (from `PrismAudio/models/dit.py`)
- Create: `prismaudio_core/models/diffusion.py` (from `PrismAudio/models/diffusion.py`)
- Create: `prismaudio_core/models/conditioners.py` (from `PrismAudio/models/conditioners.py`)
- Create: `prismaudio_core/models/autoencoders.py` (from `PrismAudio/models/autoencoders.py`)
- Create: `prismaudio_core/models/pretransforms.py` (from `PrismAudio/models/pretransforms.py`)
- Create: `prismaudio_core/models/blocks.py` (from `PrismAudio/models/blocks.py`)
- Create: `prismaudio_core/models/utils.py` (from `PrismAudio/models/utils.py`)
- Create: `prismaudio_core/models/bottleneck.py` (from `PrismAudio/models/bottleneck.py`)
- Create: `prismaudio_core/models/transformer.py` (from `PrismAudio/models/transformer.py`)
- Create: `prismaudio_core/models/local_attention.py` (if used by transformer)

**Step 1: Extract model files**

For each file, extract from the PrismAudio repo. Key modifications:
- Change all internal imports from `PrismAudio.models.*` to `prismaudio_core.models.*`
- Remove training-only code (loss functions, training step methods, gradient checkpointing setup)
- Keep all inference paths intact

**Critical classes to preserve:**

From `dit.py`:
- `DiffusionTransformer`: full class with `forward()`, CFG logic, conditioning assembly
- `FourierFeatures`: timestep embedding
- Keep `empty_clip_feat` and `empty_sync_feat` learned parameters (nn.Parameter, zeros)

From `diffusion.py`:
- `ConditionedDiffusionModelWrapper`: with `get_conditioning_inputs()` and routing logic
- `DiTWrapper`: thin wrapper that passes all kwargs through
- `create_diffusion_cond_from_config()`: factory function

From `conditioners.py`:
- `Cond_MLP` (type `"cond_mlp"`): for video_features and text_features. Uses `pad_sequence`, 2-layer MLP, returns `[embeddings, ones_mask]`. During eval with batch<16, doubles batch with null embed for CFG
- `Sync_MLP` (type `"sync_mlp"`): for sync_features with learnable `sync_pos_emb` of shape (1,1,8,dim), reshapes into segments of 8, interpolates to target length
- `MultiConditioner`: iterates over `batch_metadata: List[Dict]`, collects per-sample inputs, calls each conditioner. Returns dict of `{key: (tensor, mask)}`
- `create_multi_conditioner_from_conditioning_config()`: factory

From `autoencoders.py`:
- `AudioAutoencoder`: with `encode_audio()` and `decode_audio()`
- `OobleckEncoder`, `OobleckDecoder`: with ResidualUnit, snake activation
- Dependencies: `alias_free_torch` (SnakeBeta), `dac.nn` (WNConv1d, WNConvTranspose1d)

From `pretransforms.py`:
- `AutoencoderPretransform`: wraps AudioAutoencoder, `encode()` and `decode()` methods

From `bottleneck.py`:
- `VAEBottleneck`: reparameterization trick (split mean/logvar, sample)

From `blocks.py`:
- Any shared blocks used by the above (attention blocks, FeedForward, etc.)

From `transformer.py`:
- `ContinuousTransformer`: the core transformer with cross-attention, used by DiffusionTransformer

From `utils.py`:
- `load_ckpt_state_dict()`: handles .safetensors and .ckpt, optional prefix stripping
- `remove_weight_norm_from_model()`: used in some inference paths

**Step 2: Handle flash-attn gracefully in transformer.py / blocks.py**

Replace hard `import flash_attn` with:
```python
try:
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
```

In the attention forward pass, use:
```python
if HAS_FLASH_ATTN:
    # flash_attn_func expects (batch, seq, heads, head_dim)
    out = flash_attn_func(q, k, v, ...)
else:
    # scaled_dot_product_attention expects (batch, heads, seq, head_dim)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, ...)
```

**Step 3: Verify imports resolve**

Run: `python -c "from prismaudio_core.factory import create_model_from_config; print('OK')"` from the project root (with ComfyUI's python).

Expected: `OK` (or import errors to fix iteratively)

**Step 4: Commit**

```bash
git add prismaudio_core/models/
git commit -m "feat: extract prismaudio_core model modules (DiT, conditioners, VAE, diffusion)"
```

---

### Task 4: Extract prismaudio_core - Inference/Sampling (with callback fix)

**Files:**
- Create: `prismaudio_core/inference/__init__.py`
- Create: `prismaudio_core/inference/sampling.py` (MODIFIED from `PrismAudio/inference/sampling.py`)
- Create: `prismaudio_core/inference/utils.py` (from `PrismAudio/inference/utils.py`)

**Step 1: Extract sampling.py WITH callback support added**

The original `sample_discrete_euler` uses `tqdm` and has no callback parameter.
We MUST copy and modify it to add callback support for ComfyUI progress bars.

```python
import torch


def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Discrete Euler sampler for rectified flow, with optional callback.

    Modified from PrismAudio to add callback parameter for ComfyUI progress reporting.
    Original uses tqdm internally.

    Args:
        model: The diffusion model (DiTWrapper)
        x: Initial noise tensor [B, C, T]
        steps: Number of sampling steps
        sigma_max: Maximum sigma (default 1.0 for rectified flow)
        callback: Optional callable({"i": step, "x": current_x}) for progress
        **extra_args: Passed to model(); includes cross_attn_cond, add_cond,
            sync_cond, cfg_scale, batch_cfg, etc.
    """
    # Time grid from sigma_max down to 0, one more point than steps
    t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)

    for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])):
        dt = t_next - t_curr  # negative step: integrating from noise to data
        t_curr_tensor = t_curr * torch.ones(x.shape[0], dtype=x.dtype, device=x.device)
        x = x + dt * model(x, t_curr_tensor, **extra_args)
        if callback is not None:
            callback({"i": i, "x": x})

    return x


# Note: sample_rf() and sample() (v-diffusion) are NOT included.
# PrismAudio uses rectified_flow objective which only needs sample_discrete_euler.
# Including unused samplers with potentially wrong math is a maintenance hazard.
```
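
To see the callback contract without loading a model, the same loop can be run on plain floats. The sketch below is torch-free and rests on an assumed toy setup: a constant velocity `noise - data`, for which Euler steps from t=1 to t=0 land on `data` up to floating-point rounding. `euler_toy` and `record` are hypothetical names.

```python
def euler_toy(velocity, x, steps, sigma_max=1.0, callback=None):
    """Scalar analogue of the discrete Euler loop, stepping t from sigma_max to 0."""
    times = [sigma_max * (1 - k / steps) for k in range(steps + 1)]
    for i, (t_curr, t_next) in enumerate(zip(times[:-1], times[1:])):
        dt = t_next - t_curr  # negative: time runs backwards
        x = x + dt * velocity(x, t_curr)
        if callback is not None:
            callback({"i": i, "x": x})
    return x


noise, data = 3.0, -1.0
seen_steps = []


def record(info):
    """Stand-in for a progress-bar update; stores the step index."""
    seen_steps.append(info["i"])


result = euler_toy(lambda x, t: noise - data, noise, steps=8, callback=record)
```

In a node, the callback would typically forward `info["i"]` to a ComfyUI progress bar instead of appending to a list.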

**Step 2: Extract inference/utils.py**

Keep:
- `set_audio_channels(audio, target_channels)`
- `prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device)`

**Step 3: Verify sampling import**

Run: `python -c "from prismaudio_core.inference.sampling import sample_discrete_euler; print('OK')"`

Expected: `OK`

**Step 4: Commit**

```bash
git add prismaudio_core/inference/
git commit -m "feat: extract prismaudio_core inference with callback-enabled sampling"
```

---

### Task 5: PrismAudioModelLoader Node

**Files:**
- Create: `nodes/model_loader.py`

**Step 1: Write the node**

Key design decisions:
- No hf_token widget (security risk: it would be saved to workflow JSON). Uses env var / cached login only.
- Creates model with default config. No duration-dependent sequence-length configuration is needed.
- The model config's `sample_size: 397312` corresponds to the ~9s default (397312 / 44100 ≈ 9.01 s). For other durations, the model adapts to the input tensor shapes at sample time, so the Sampler node does not update seq lengths on the DiT (see fixes 5, 17 and 27).

```python
import os
import json
import torch
import folder_paths
import comfy.model_management as mm
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
    get_device, get_offload_device, determine_precision, determine_offload_strategy,
    soft_empty_cache, resolve_hf_token,
)

# HuggingFace repo for auto-download
HF_REPO_ID = "FunAudioLLM/PrismAudio"
REQUIRED_FILES = {
    "diffusion": "prismaudio.ckpt",
    "vae": "vae.ckpt",
    "synchformer": "synchformer_state_dict.pth",
}


def _download_if_missing(filename, model_dir, hf_token=None):
    """Download a model file from HuggingFace if not present locally."""
    filepath = os.path.join(model_dir, filename)
    if os.path.exists(filepath):
        return filepath

    from huggingface_hub import hf_hub_download
    print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
    try:
        downloaded = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=model_dir,
            token=hf_token or None,
        )
        return downloaded
    except Exception as e:
        if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower():
            raise RuntimeError(
                f"[PrismAudio] Model '{filename}' requires license acceptance. "
                f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
                f"then set HF_TOKEN env var or run: huggingface-cli login"
            ) from e
        raise


class PrismAudioModelLoader:
    @classmethod
    def INPUT_TYPES(cls):
        register_model_folder()
        return {
            "required": {
                "precision": (["auto", "fp32", "fp16", "bf16"],),
                "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_model(self, precision, offload_strategy):
        device = get_device()
        dtype = determine_precision(precision, device)
        strategy = determine_offload_strategy(offload_strategy)
        token = resolve_hf_token()
        model_dir = get_prismaudio_model_dir()

        # Auto-download missing files
        for key, filename in REQUIRED_FILES.items():
            _download_if_missing(filename, model_dir, hf_token=token)

        # Load config
        config_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "prismaudio_core", "configs", "prismaudio.json"
        )
        with open(config_path) as f:
            model_config = json.load(f)

        # Create model from config
        from prismaudio_core.factory import create_model_from_config
        model = create_model_from_config(model_config)

        # Load diffusion weights
        diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"])
        diffusion_state = comfy.utils.load_torch_file(diffusion_path)
        # Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
        if "state_dict" in diffusion_state:
            diffusion_state = diffusion_state["state_dict"]
        model.load_state_dict(diffusion_state, strict=False)

        # Load VAE weights separately
        # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
        vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
        vae_full_state = comfy.utils.load_torch_file(vae_path)
        # Strip "autoencoder." prefix from keys
        vae_state = {}
        prefix = "autoencoder."
        for k, v in vae_full_state.items():
            if k.startswith(prefix):
                vae_state[k[len(prefix):]] = v
            else:
                vae_state[k] = v
        model.pretransform.load_state_dict(vae_state)

        # Apply precision: DiT + conditioners in user-selected dtype,
        # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
        model.model.to(dtype)  # DiTWrapper
        model.conditioner.to(dtype)  # MultiConditioner
        # model.pretransform stays in fp32

        if strategy == "keep_in_vram":
            model = model.to(device)
        else:
            model = model.to(get_offload_device())

        model.eval()

        return ({
            "model": model,
            "dtype": dtype,
            "strategy": strategy,
            "config": model_config,
            "model_dir": model_dir,
        },)
```
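
The two state-dict adjustments in `load_model` (unwrapping a `state_dict` key and stripping the `autoencoder.` prefix) are plain dictionary operations, so they can be checked with string values standing in for tensors. The helper names below are hypothetical and are not part of PrismAudio or ComfyUI.

```python
def unwrap_state_dict(checkpoint):
    """Return the inner mapping when a checkpoint is wrapped in {"state_dict": ...}."""
    return checkpoint.get("state_dict", checkpoint)


def strip_prefix(state, prefix):
    """Drop `prefix` from keys that carry it; leave other keys unchanged."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }


wrapped = {"state_dict": {"autoencoder.encoder.weight": "w0", "bottleneck.bias": "b0"}}
vae_state = strip_prefix(unwrap_state_dict(wrapped), "autoencoder.")
```

Keeping these as small functions would also make it straightforward to log which keys were renamed when a checkpoint fails to load.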

**Step 2: Test that ComfyUI discovers the node**

Run ComfyUI and check that "PrismAudio Model Loader" appears in the node list.

**Step 3: Commit**

```bash
git add nodes/model_loader.py
git commit -m "feat: PrismAudioModelLoader node with auto-download and adaptive VRAM"
```

---

### Task 6: PrismAudioFeatureLoader Node

**Files:**
- Create: `nodes/feature_loader.py`

**Step 1: Write the node**

```python
import os
import numpy as np
import torch
from .utils import PRISMAUDIO_CATEGORY

# Keys consumed by the conditioners (video_features, text_features, sync_features)
# global_video_features and global_text_features are NOT consumed by any conditioner
# in the prismaudio.json config; they are unused.
REQUIRED_KEYS = [
    "video_features",
    "text_features",
    "sync_features",
]


class PrismAudioFeatureLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
    RETURN_NAMES = ("features",)
    FUNCTION = "load_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_features(self, npz_path):
        if not os.path.exists(npz_path):
            raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")

        # allow_pickle=True can run arbitrary code from a crafted file; load trusted .npz files only
        data = np.load(npz_path, allow_pickle=True)

        features = {}
        for key in REQUIRED_KEYS:
            if key in data:
                features[key] = torch.from_numpy(data[key]).float()
            else:
                print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
                # Provide zero tensor rather than None, because Cond_MLP/Sync_MLP crash on None
                # Sync_MLP requires length divisible by 8 (segments of 8 frames)
                if key == "sync_features":
                    features[key] = torch.zeros(8, 768)
                else:
                    features[key] = torch.zeros(1, 1024)

        # Load duration if present
        if "duration" in data:
            features["duration"] = float(data["duration"])

        return (features,)
```

**Step 2: Commit**

```bash
git add nodes/feature_loader.py
git commit -m "feat: PrismAudioFeatureLoader node for pre-computed .npz files"
```

---

### Task 7: PrismAudioFeatureExtractor Node (Subprocess Bridge)

**Files:**
- Create: `nodes/feature_extractor.py`
- Create: `scripts/extract_features.py`
- Create: `scripts/environment.yml`

**Step 1: Create the conda environment.yml**

```yaml
name: prismaudio-extract
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.10
  - pip
  - ffmpeg<7
  - pip:
    - torch>=2.6.0
    - torchaudio>=2.6.0
    - torchvision>=0.21.0
    - tensorflow-cpu==2.15.0
    - jax
    - jaxlib
    - transformers>=4.52.3
    - decord
    - einops>=0.7.0
    - numpy
    - mediapy
    - git+https://github.com/google-deepmind/videoprism.git
```

**Step 2: Create scripts/extract_features.py**

This is a standalone script that:
1. Takes `--video`, `--cot_text`, `--output` arguments
2. Loads VideoPrism, T5-Gemma, Synchformer
3. Extracts features from the video
4. Saves as `.npz`

```python
#!/usr/bin/env python3
"""
Standalone PrismAudio feature extraction script.
Run in a separate conda env with JAX/TF installed.

Usage:
    python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz

Setup:
    conda env create -f environment.yml
    conda activate prismaudio-extract
"""

import argparse
import os
import sys
import numpy as np
import torch


def main():
    parser = argparse.ArgumentParser(description="PrismAudio feature extraction")
    parser.add_argument("--video", required=True, help="Path to input video")
    parser.add_argument("--cot_text", required=True, help="Chain-of-thought description")
    parser.add_argument("--output", required=True, help="Output .npz path")
    parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
    parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
    parser.add_argument("--clip_fps", type=float, default=4.0)
    parser.add_argument("--clip_size", type=int, default=288)
    parser.add_argument("--sync_fps", type=float, default=25.0)
    parser.add_argument("--sync_size", type=int, default=224)
    args = parser.parse_args()

    if not os.path.exists(args.video):
        print(f"Error: Video not found: {args.video}")
        sys.exit(1)

    # Import feature extraction utils (requires JAX/TF)
    from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
    import torchvision.transforms as T
    from decord import VideoReader, cpu

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize feature extractor
    feat_utils = FeaturesUtils(
        vae_config_path=args.vae_config,
        synchformer_ckpt=args.synchformer_ckpt,
        device=device,
    )

    # Load and preprocess video
    vr = VideoReader(args.video, ctx=cpu(0))
    fps = vr.get_avg_fps()
    total_frames = len(vr)
    duration = total_frames / fps

    # Extract CLIP frames (4fps, 288x288)
    clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
    clip_indices = [min(i, total_frames - 1) for i in clip_indices]
    clip_frames = vr.get_batch(clip_indices).asnumpy()

    clip_transform = T.Compose([
        T.ToPILImage(),
        T.Resize(args.clip_size),
        T.CenterCrop(args.clip_size),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device)

    # Extract Sync frames (25fps, 224x224)
    sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
    sync_indices = [min(i, total_frames - 1) for i in sync_indices]
    sync_frames = vr.get_batch(sync_indices).asnumpy()

    sync_transform = T.Compose([
        T.ToPILImage(),
        T.Resize(args.sync_size),
        T.CenterCrop(args.sync_size),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device)

    # Extract features
    print("[PrismAudio] Encoding text with T5-Gemma...")
    text_features = feat_utils.encode_t5_text([args.cot_text])

    print("[PrismAudio] Encoding video with VideoPrism...")
    global_video_features, video_features, global_text_features = \
        feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text])

    print("[PrismAudio] Encoding video with Synchformer...")
    sync_features = feat_utils.encode_video_with_sync(sync_input)

    # Save as .npz
    np.savez(
        args.output,
        video_features=video_features.cpu().numpy(),
        global_video_features=global_video_features.cpu().numpy(),
        text_features=text_features.cpu().numpy(),
        global_text_features=global_text_features.cpu().numpy(),
        sync_features=sync_features.cpu().numpy(),
        caption_cot=args.cot_text,
        duration=duration,
    )
    print(f"[PrismAudio] Features saved to {args.output}")


if __name__ == "__main__":
    main()
```
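
The index arithmetic used for both frame streams can be checked on its own. This sketch factors it into a hypothetical `resample_indices` helper. It reproduces the list comprehensions from the script (truncate, then clamp to the last frame) and makes no claim about how the reference pipeline picks frames.

```python
def resample_indices(source_fps, target_fps, total_frames):
    """Pick source frame indices that approximate `target_fps` sampling."""
    duration = total_frames / source_fps
    count = int(duration * target_fps)
    indices = [int(i * source_fps / target_fps) for i in range(count)]
    # Clamp so truncation can never index past the final frame.
    return [min(i, total_frames - 1) for i in indices]


clip_idx = resample_indices(source_fps=30.0, target_fps=4.0, total_frames=60)
sync_idx = resample_indices(source_fps=25.0, target_fps=25.0, total_frames=50)
```

For a 2-second, 30 fps clip this yields 8 indices for the 4 fps stream, spaced 7 or 8 source frames apart.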

**Step 3: Create the feature extractor node**

```python
import os
import hashlib
import subprocess
import tempfile
import torch

from .utils import PRISMAUDIO_CATEGORY
from .feature_loader import PrismAudioFeatureLoader


def _hash_inputs(video_tensor, cot_text):
    """Create a hash of the inputs for caching."""
    h = hashlib.sha256()
    # Include the shape so clips that share a prefix but differ in size do not collide
    h.update(str(tuple(video_tensor.shape)).encode())
    # First 1MB (262144 float32 values) for speed; slice before copying to bytes
    prefix = video_tensor.flatten()[:262144].contiguous().cpu().numpy()
    h.update(prefix.tobytes())
    h.update(cot_text.encode())
    return h.hexdigest()[:16]


def _save_video_tensor_to_mp4(video_tensor, output_path, fps=30):
    """Save ComfyUI IMAGE tensor [T,H,W,C] to MP4."""
    import torchvision.io as tvio
    # ComfyUI IMAGE is [T,H,W,C] float32 [0,1]; clamp so out-of-range values do not wrap
    frames = (video_tensor.clamp(0, 1) * 255).to(torch.uint8).cpu()
    # torchvision write_video expects [T,H,W,C] uint8
    tvio.write_video(output_path, frames, fps=fps)


class PrismAudioFeatureExtractor:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": ("IMAGE",),
                "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
            },
            "optional": {
                "python_env": ("STRING", {"default": "python", "tooltip": "Path to python binary with JAX/TF (e.g., /path/to/conda/envs/prismaudio-extract/bin/python)"}),
                "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
                "synchformer_ckpt": ("STRING", {"default": "", "tooltip": "Path to synchformer checkpoint (auto-resolved if empty)"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
    RETURN_NAMES = ("features",)
    FUNCTION = "extract_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def extract_features(self, video, caption_cot, python_env="python", cache_dir="", synchformer_ckpt=""):
        # Determine cache directory
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features")
        os.makedirs(cache_dir, exist_ok=True)

        # Check cache
        cache_hash = _hash_inputs(video, caption_cot)
        cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
        if os.path.exists(cached_path):
            print(f"[PrismAudio] Using cached features: {cached_path}")
            loader = PrismAudioFeatureLoader()
            return loader.load_features(cached_path)

        # Save video to temp file
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_video = tmp.name
        _save_video_tensor_to_mp4(video, tmp_video)

        # Build subprocess command
        script_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "scripts", "extract_features.py"
        )

        cmd = [
            python_env,
            script_path,
            "--video", tmp_video,
            "--cot_text", caption_cot,
            "--output", cached_path,
        ]
        if synchformer_ckpt:
            cmd.extend(["--synchformer_ckpt", synchformer_ckpt])

        print("[PrismAudio] Extracting features via subprocess...")
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=600,  # 10 minute timeout
            )
            if result.returncode != 0:
                raise RuntimeError(
                    f"[PrismAudio] Feature extraction failed:\n{result.stderr}"
                )
            print(result.stdout)
        finally:
            if os.path.exists(tmp_video):
                os.unlink(tmp_video)

        # Load the extracted features
        loader = PrismAudioFeatureLoader()
        return loader.load_features(cached_path)
```
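
Because the cache key is derived from a prefix of the frame data, two clips that open with identical frames can map to the same cached `.npz`. The following is a minimal sketch of a key that also mixes in the tensor shape and total byte length. It works on raw `bytes` so it runs without PyTorch, and `cache_key_sketch` and its parameters are hypothetical names, not part of the plan.

```python
import hashlib


def cache_key_sketch(frame_bytes: bytes, shape: tuple, cot_text: str,
                     prefix_bytes: int = 1024 * 1024) -> str:
    """Hypothetical cache key: shape + total length + byte prefix + caption."""
    h = hashlib.sha256()
    h.update(repr(tuple(shape)).encode())             # distinguishes resolution / frame count
    h.update(len(frame_bytes).to_bytes(8, "little"))  # distinguishes clip length
    h.update(frame_bytes[:prefix_bytes])              # bounded amount of pixel data
    h.update(b"\x00")                                 # separator between frames and text
    h.update(cot_text.encode("utf-8"))
    return h.hexdigest()[:16]


# Same leading bytes, different shapes: the keys differ
key_a = cache_key_sketch(b"a" * 10, (1, 2, 5, 1), "rain on a roof")
key_b = cache_key_sketch(b"a" * 10, (2, 1, 5, 1), "rain on a roof")
```

Clips with the same shape and the same first megabyte would still collide; hashing the full tensor removes that risk at the cost of speed.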

**Step 4: Commit**

```bash
git add nodes/feature_extractor.py scripts/extract_features.py scripts/environment.yml
git commit -m "feat: PrismAudioFeatureExtractor node with subprocess bridge and conda env"
```

---

### Task 8: PrismAudioSampler Node

**Files:**
- Create: `nodes/sampler.py`

**Step 1: Write the sampler node**

This is the core node. Key fixes from review:
- Metadata is a tuple of dicts (one per batch sample), not a flat dict
- `video_exist` is a `torch.tensor`, not a Python bool
- Empty features are zero tensors, not `None`
- Peak normalization before clamp
- No sequence-length configuration on the DiT: the model adapts to input tensor shapes dynamically
- No callback kwarg forwarded to the model; the callback is handled by our modified `sample_discrete_euler`
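
To illustrate the last point, here is a toy scalar sketch of an Euler loop that consumes the callback itself instead of forwarding it to the model. It is not the copied `sample_discrete_euler`: the name `euler_sample_sketch`, the `velocity_fn` signature, and the keys of the info dict are assumptions made for illustration.

```python
from typing import Callable, Optional


def euler_sample_sketch(
    velocity_fn: Callable[[float, float], float],
    x: float,
    steps: int,
    callback: Optional[Callable[[dict], None]] = None,
    sigma_max: float = 1.0,
) -> float:
    """Toy Euler integration of dx/dt = v(x, t) from t = sigma_max down to 0."""
    # Evenly spaced time grid with steps + 1 points
    ts = [sigma_max * (1 - i / steps) for i in range(steps + 1)]
    for i, (t_curr, t_next) in enumerate(zip(ts[:-1], ts[1:])):
        dt = t_next - t_curr  # negative: integrating from noise towards data
        x = x + dt * velocity_fn(x, t_curr)
        if callback is not None:
            # Handled here by the sampler; never passed on to velocity_fn
            callback({"i": i, "t": t_curr, "x": x})
    return x


calls = []
result = euler_sample_sketch(lambda x, t: 1.0, x=1.0, steps=4, callback=calls.append)
```

The callback fires once per step, which is what lets `pbar.update(1)` advance a progress bar created with the same step count.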

```python
import torch
import comfy.model_management as mm
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
    get_device, get_offload_device, soft_empty_cache,
)


class PrismAudioSampler:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "features": ("PRISMAUDIO_FEATURES",),
                "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds"}),
                "steps": ("INT", {"default": 24, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
                "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, features, duration, steps, cfg_scale, seed):
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]

        # Compute latent dimensions
        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Note: no seq length config needed. The model adapts to input tensor shapes
        # dynamically via its transformer architecture.

        # Determine if video features are present (not all zeros).
        # bool(...) converts the 0-dim tensor so torch.tensor(has_video) below gets a plain bool.
        video_features = features.get("video_features")
        has_video = video_features is not None and bool(video_features.abs().sum() > 0)

        # Build metadata as a TUPLE of dicts (one per batch sample)
        # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
        sample_meta = {
            "video_features": features["video_features"].to(device, dtype=dtype),
            "text_features": features["text_features"].to(device, dtype=dtype),
            "sync_features": features["sync_features"].to(device, dtype=dtype),
            "video_exist": torch.tensor(has_video),
        }
        metadata = (sample_meta,)

        # Move model to device if offloaded
        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            # Run conditioning
            conditioning = diffusion.conditioner(metadata, device)

            # Handle missing video: substitute learned empty embeddings
            if not has_video:
                _substitute_empty_features(diffusion, conditioning, device, dtype)

            # Assemble conditioning inputs for the DiT
            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (on MPS the generator is created on CPU,
            # since device-side generators are not reliably supported there)
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            # Sample with progress bar
            pbar = comfy.utils.ProgressBar(steps)

            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

            # Offload diffusion model and conditioner before VAE decode
            if strategy == "offload_to_cpu":
                diffusion.model.to(get_offload_device())
                diffusion.conditioner.to(get_offload_device())
                soft_empty_cache()
                diffusion.pretransform.to(device)

            # VAE decode in fp32 (snake activations overflow in fp16)
            with torch.amp.autocast(device_type=device.type, enabled=False):
                audio = diffusion.pretransform.decode(fakes.float())

            # Offload VAE
            if strategy == "offload_to_cpu":
                diffusion.pretransform.to(get_offload_device())
                soft_empty_cache()

        # Peak normalize then clamp (matching reference: div by max abs before clamp)
        audio = audio.float()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)

        # Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int}
        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)


def _substitute_empty_features(diffusion, conditioning, device, dtype):
    """Replace sync conditioning with learned empty embedding when video is absent.

    Only substitutes sync_features, NOT video_features. The reference code
    (predict.py/app.py) checks for 'metaclip_features' which doesn't exist in the
    prismaudio.json config, so video substitution never runs. Cond_MLP with zero
    input + bias-free linear layers naturally produces near-zero output.

    The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim].
    """
    dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model

    # Only substitute sync_features (matching reference behavior for prismaudio config)
    if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning:
        empty = dit.empty_sync_feat.to(device, dtype=dtype)
        cond_tensor = conditioning['sync_features'][0]
        batch_size = cond_tensor.shape[0]
        empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1)
        conditioning['sync_features'][0] = empty_expanded
        conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device)
```
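
The final step of `generate` divides by the peak amplitude and only then clamps; the `1e-8` floor keeps a fully silent output from causing a division by zero. A dependency-free sketch of the same arithmetic on a plain list of floats (the helper name `peak_normalize_sketch` is hypothetical, and the node itself operates on tensors):

```python
def peak_normalize_sketch(samples, eps=1e-8):
    """Scale samples so the largest magnitude is 1.0, then clamp to [-1, 1]."""
    # Floor the peak at eps so all-zero input stays all-zero instead of dividing by 0
    peak = max(max((abs(s) for s in samples), default=0.0), eps)
    return [max(-1.0, min(1.0, s / peak)) for s in samples]


loud = peak_normalize_sketch([0.5, -2.0])
silent = peak_normalize_sketch([0.0, 0.0])
```

Clamping alone would hard-clip the `-2.0` sample and leave the `0.5` sample untouched; normalizing first preserves the relative levels.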

**Step 2: Verify the node registers**

Start ComfyUI and check that "PrismAudio Sampler" appears in the add-node menu.

**Step 3: Commit**

```bash
git add nodes/sampler.py
git commit -m "feat: PrismAudioSampler node with correct metadata format and peak normalization"
```

---

### Task 9: PrismAudioTextOnly Node

**Files:**
- Create: `nodes/text_only.py`

**Step 1: Write the text-only node**

Key fixes from review:
- Uses `AutoModelForSeq2SeqLM.get_encoder()`, not `AutoModel.encoder`
- No truncation (matching reference)
- Metadata is a tuple of dicts with `torch.tensor(False)` for `video_exist`
- Zero tensors for video/sync features, not `None`
- Peak normalization

```python
import torch
import comfy.model_management as mm
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
    get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
)
from .sampler import _substitute_empty_features


class PrismAudioTextOnly:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Text description for audio generation"}),
                "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
                "steps": ("INT", {"default": 24, "min": 1, "max": 100}),
                "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]

        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Encode text with T5-Gemma
        text_features = _encode_text_t5(text_prompt, device, dtype)

        # Build metadata: tuple of one dict per sample
        # Use zero tensors for video/sync (not None: Cond_MLP crashes on None via pad_sequence)
        # Sync_MLP requires length divisible by 8 (segments of 8 frames), so the minimum is [8, 768]
        # The sync conditioning is replaced with the learned empty embedding after conditioning
        sample_meta = {
            "video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
            "text_features": text_features.to(device, dtype=dtype),
            "sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
            "video_exist": torch.tensor(False),
        }
        metadata = (sample_meta,)

        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            conditioning = diffusion.conditioner(metadata, device)

            # Substitute the learned empty embedding for sync
            # (video conditioning is left as the Cond_MLP output of the zero tensor)
            _substitute_empty_features(diffusion, conditioning, device, dtype)

            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (on MPS the generator is created on CPU,
            # since device-side generators are not reliably supported there)
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            pbar = comfy.utils.ProgressBar(steps)

            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

            if strategy == "offload_to_cpu":
                diffusion.model.to(get_offload_device())
                diffusion.conditioner.to(get_offload_device())
                soft_empty_cache()
                diffusion.pretransform.to(device)

            # VAE decode in fp32 (snake activations overflow in fp16)
            with torch.amp.autocast(device_type=device.type, enabled=False):
                audio = diffusion.pretransform.decode(fakes.float())

            if strategy == "offload_to_cpu":
                diffusion.pretransform.to(get_offload_device())
                soft_empty_cache()

        # Peak normalize then clamp
        audio = audio.float()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)

        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)


# T5-Gemma encoder singleton
_t5_model = None
_t5_tokenizer = None


def _encode_text_t5(text, device, dtype):
    """Encode text using T5-Gemma.

    Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference
    FeaturesUtils.encode_t5_text() implementation.
    No truncation applied (matching reference behavior).
    """
    global _t5_model, _t5_tokenizer

    if _t5_model is None:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        model_id = "google/t5gemma-l-l-ul2-it"
        token = resolve_hf_token()
        print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
        _t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        _t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder()
        _t5_model.eval()

    _t5_model.to(device, dtype=dtype)

    tokens = _t5_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        outputs = _t5_model(**tokens)

    # Move T5 off GPU after encoding to save VRAM
    _t5_model.to("cpu")
    soft_empty_cache()

    return outputs.last_hidden_state.squeeze(0)  # [seq_len, dim]
```

**Step 2: Commit**

```bash
git add nodes/text_only.py
git commit -m "feat: PrismAudioTextOnly node with correct T5-Gemma encoding"
```

---

### Task 10: Integration Testing & Polish

**Files:**
- Modify: `nodes/__init__.py` (verify all imports work)
- Modify: `__init__.py` (verify top-level registration)

**Step 1: Verify all node imports resolve**

Run from ComfyUI's Python:
```bash
cd /path/to/ComfyUI
python -c "
import sys
sys.path.insert(0, 'custom_nodes/ComfyUI-PrismAudio')
from nodes import NODE_CLASS_MAPPINGS
print('Registered nodes:', list(NODE_CLASS_MAPPINGS.keys()))
"
```

Expected output:
```
Registered nodes: ['PrismAudioModelLoader', 'PrismAudioFeatureLoader', 'PrismAudioFeatureExtractor', 'PrismAudioSampler', 'PrismAudioTextOnly']
```

**Step 2: Fix any import errors iteratively**

Common issues:
- `prismaudio_core` internal imports may reference wrong module paths
- Missing model submodules in `prismaudio_core/models/`
- flash-attn fallback not properly guarded

**Step 3: Test model loading (requires GPU + model files)**

```bash
python -c "
from prismaudio_core.factory import create_model_from_config
import json
with open('prismaudio_core/configs/prismaudio.json') as f:
    config = json.load(f)
model = create_model_from_config(config)
print('Model created, params:', sum(p.numel() for p in model.parameters()) / 1e6, 'M')
"
```

Expected: `Model created, params: ~518 M`

**Step 4: End-to-end test with pre-computed features**

If you have a `.npz` feature file from the PrismAudio repo's demo data, test the full pipeline in ComfyUI:
1. PrismAudioModelLoader -> PrismAudioFeatureLoader -> PrismAudioSampler -> Preview Audio node

**Step 5: Verify variable duration handling**

Test with multiple durations (5s, 10s, 20s) to ensure the model adapts to different
input shapes and produces audio of the expected length.
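
For the duration check, it helps to know the length to expect. The sketch below applies the same `round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)` formula used by the sampler nodes. The constant values (44100 Hz, ratio 2048) are placeholders chosen for illustration, so substitute the real ones from `nodes/utils.py`; it also assumes the VAE decoder upsamples by exactly `DOWNSAMPLING_RATIO`.

```python
from typing import Tuple

# Assumed values for illustration only; use the constants from nodes/utils.py
SAMPLE_RATE = 44100
DOWNSAMPLING_RATIO = 2048


def expected_lengths(duration_s: float) -> Tuple[int, int]:
    """Return (latent_length, decoded_samples) for a requested duration."""
    latent_length = round(SAMPLE_RATE * duration_s / DOWNSAMPLING_RATIO)
    # Assumption: one latent frame decodes to DOWNSAMPLING_RATIO audio samples
    return latent_length, latent_length * DOWNSAMPLING_RATIO


for seconds in (5, 10, 20):
    latents, samples = expected_lengths(seconds)
    print(f"{seconds:>2}s -> {latents} latent frames, {samples} samples "
          f"({samples / SAMPLE_RATE:.3f}s)")
```

Because the latent length is rounded, the decoded audio can differ from the requested duration by up to half a latent frame, so compare against the computed sample count rather than `duration * SAMPLE_RATE`.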

**Step 6: Commit**

```bash
git add -A
git commit -m "feat: integration fixes and verification"
```

---

### Task 11: README

**Files:**
- Create: `README.md`

**Step 1: Write README covering:**

- What PrismAudio is (brief, link to paper)
- Installation (clone, pip install requirements, optional extraction env setup)
- Node descriptions with input/output tables
- Example workflows (quality path with FeatureExtractor, quick path with FeatureLoader, text-only)
- HuggingFace authentication (2 methods: `HF_TOKEN` env var, `huggingface-cli login`)
- Note: hf_token is NOT a node widget for security reasons
- Which models may be gated (T5-Gemma, potentially Stable Audio VAE)
- Model file sizes: diffusion ~2.7GB, VAE ~2.5GB, synchformer ~950MB
- Extraction env setup via conda environment.yml
- Troubleshooting (VRAM, flash-attn optional, gated models)
- Credits and license
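
For the authentication section of the README, a short sketch of the lookup order may be useful. It assumes that passing `token=None` to `from_pretrained` lets `huggingface_hub` fall back to the token stored by `huggingface-cli login`. The function name `resolve_token_sketch` is hypothetical and is not the plan's `resolve_hf_token` helper.

```python
import os
from typing import Mapping, Optional


def resolve_token_sketch(env: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return HF_TOKEN if set and non-empty, else None.

    Assumption: None makes the Hub client use the token cached by
    `huggingface-cli login`, if one exists.
    """
    token = env.get("HF_TOKEN", "").strip()
    return token or None


from_env = resolve_token_sketch({"HF_TOKEN": " hf_example "})
from_login = resolve_token_sketch({})
```

Keeping the token in the environment or the CLI cache, rather than in a node widget, means it is never written into a saved workflow file.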

**Step 2: Commit**

```bash
git add README.md
git commit -m "docs: README with installation and usage instructions"
```

---

## Dependency Graph

```
Task 1 (scaffolding)
  ├── Task 2 (core config + factory) ──┐
  │     └── Task 3 (core models) ──────┤
  │           └── Task 4 (core sampling)┤
  │                                     ├── Task 5 (ModelLoader node)
  │                                     ├── Task 6 (FeatureLoader node)
  │                                     ├── Task 7 (FeatureExtractor node)
  │                                     ├── Task 8 (Sampler node)
  │                                     └── Task 9 (TextOnly node)
  └────────────────────────────────────────── Task 10 (Integration)
                                              └── Task 11 (README)
```

Tasks 5-9 can be parallelized after Task 4 is complete. Task 3 is the heaviest: it involves extracting and adapting ~10 model files.
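
The ordering can also be checked mechanically. The sketch below encodes one reading of the graph (Tasks 5-9 each depend on Task 4, Task 10 depends on all of Tasks 5-9, and Task 11 follows Task 10) and groups the tasks into waves that can run in parallel, using the standard-library `graphlib` module (Python 3.9+). The `deps` mapping is an interpretation of the diagram, not something the plan states explicitly.

```python
from graphlib import TopologicalSorter

# task -> set of tasks it depends on (one reading of the diagram)
deps = {
    2: {1}, 3: {2}, 4: {3},
    5: {4}, 6: {4}, 7: {4}, 8: {4}, 9: {4},
    10: {5, 6, 7, 8, 9},
    11: {10},
}


def parallel_waves(dependencies):
    """Group tasks into waves; every task in a wave has all prerequisites done."""
    sorter = TopologicalSorter(dependencies)
    sorter.prepare()
    waves = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())
        waves.append(ready)
        sorter.done(*ready)
    return waves


waves = parallel_waves(deps)
print(waves)  # [[1], [2], [3], [4], [5, 6, 7, 8, 9], [10], [11]]
```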