Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0f60a9b2bf | |||
| 51f93f9688 | |||
| a315093743 | |||
| e49f760b77 | |||
| 4f40e15db3 | |||
| 08d73773c5 |
@@ -0,0 +1,167 @@
# SelVA Integration Design

**Date:** 2026-04-04
**Branch:** feature/selva-integration (new from master)
**Status:** Approved, ready for implementation

---

## Problem

PrismAudio's sync conditioning is text-agnostic: Synchformer extracts features from
all visual motion equally. In multi-source videos (person walking near a car), the DiT
receives unfocused sync guidance and struggles to match audio events to the correct
visual source.

SelVA (CVPR 2026, arXiv:2512.02650) solves this with TextSynchformer: text conditioning
is injected inside the Synchformer encoder via cross-attention, so sync features only
encode motion relevant to the requested sound. This is the core architectural improvement
needed for reliable V2A sync.

---

## Architecture

### New directory layout

```
selva_core/          ← vendored SelVA source (model + ext + utils)
nodes/
  selva_model_loader.py
  selva_feature_extractor.py
  selva_sampler.py
```

### New custom types

- `SELVA_MODEL`: `{generator, video_enc, feature_utils, variant, mode, strategy, dtype, seq_cfg}`
- `SELVA_FEATURES`: `{clip_features, sync_features, duration}`

### No subprocess

SelVA is pure PyTorch. Feature extraction runs inline in ComfyUI: no managed venv,
no JAX/TF, no pip install on first run.

### Dependencies

Zero new pip packages. ComfyUI already ships:
- `open_clip_torch` (CLIP ViT-H-14-384, auto-downloads via `hf-hub:` on first use)
- `transformers` (flan-t5-base, auto-downloads from HuggingFace on first use)
- `torch`, `torchaudio`, `einops`

---

## Nodes

### `SelvaModelLoader` → `SELVA_MODEL`

| Input | Type | Default | Notes |
|---|---|---|---|
| variant | dropdown | medium_44k | small_16k / small_44k / medium_44k / large_44k |
| precision | dropdown | bf16 | bf16 / fp16 / fp32 |
| offload_strategy | dropdown | auto | auto / keep_in_vram / offload_to_cpu |

Resolves weights from `models/selva/`. Raises descriptive errors with download
instructions if files are missing.

### `SelvaFeatureExtractor` → `SELVA_FEATURES`, `FLOAT` (fps)

| Input | Type | Default | Notes |
|---|---|---|---|
| model | SELVA_MODEL | (none) | Supplies TextSynchformer and the CLIP/T5 encoders |
| video | IMAGE | (none) | ComfyUI video tensor [T,H,W,C] |
| prompt | STRING | (none) | Used by TextSynchformer to select relevant motion |
| video_info | VHS_VIDEOINFO | opt | Auto-sets fps when connected |
| fps | FLOAT | 30.0 | Fallback fps if video_info not connected |
| duration | FLOAT | 0.0 | 0 = infer from video length and fps |
| cache_dir | STRING | "" | Empty = system temp dir |

Feature extraction steps (all inline, no subprocess):
1. Resize frames to 384×384 → CLIP video features `[B, T, 1024]`
2. Resize frames to 224×224 + encode prompt with flan-T5 → TextSynchformer → text-conditioned sync features `[B, T, 768]`
3. Save to `.npz` cache keyed by hash(frames[:1MB] + prompt + fps)
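
The cache key in step 3 can be sketched with `hashlib`. This is a minimal illustration, assuming the frame data is already available as raw bytes; `feature_cache_key` and its parameters are hypothetical names, not part of SelVA or ComfyUI, and further inputs could be appended through `extra` in the same way.

```python
import hashlib


def feature_cache_key(frame_bytes: bytes, prompt: str, fps: float,
                      *extra: str, max_bytes: int = 1024 * 1024) -> str:
    """Hash a bounded prefix of the frame data plus the text inputs."""
    h = hashlib.sha256()
    h.update(frame_bytes[:max_bytes])      # only the first 1 MB of frame data
    h.update(prompt.encode("utf-8"))       # a different prompt yields a different key
    h.update(repr(float(fps)).encode())    # so does a different frame rate
    for item in extra:                     # optional extra inputs (hypothetical)
        h.update(item.encode("utf-8"))
    return h.hexdigest()[:16]


frames = bytes(range(256)) * 8             # stand-in for real frame data
key_a = feature_cache_key(frames, "dog barking", 30.0)
key_b = feature_cache_key(frames, "car engine", 30.0)
print(key_a != key_b)
```

Hashing only a prefix keeps the lookup cheap for long videos, at the cost that two videos sharing the same first megabyte map to the same key unless another distinguishing input is hashed as well.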

### `SelvaSampler` → `AUDIO`

| Input | Type | Default | Notes |
|---|---|---|---|
| model | SELVA_MODEL | (none) | |
| features | SELVA_FEATURES | (none) | |
| prompt | STRING | (none) | Should match extractor prompt; drives CLIP text guidance |
| negative_prompt | STRING | "" | Steers away from unwanted sounds |
| duration | FLOAT | 0.0 | 0 = auto from features duration |
| steps | INT | 25 | Euler steps (25 is SelVA default, fast) |
| cfg_strength | FLOAT | 4.5 | CFG scale (SelVA default) |
| seed | INT | 0 | |

Generation steps:
1. Encode prompt → CLIP text features (for MMAudio)
2. Encode negative prompt → empty conditions for CFG
3. `net_generator.preprocess_conditions(clip_f, sync_f, text_clip)`
4. Flow matching Euler ODE (`num_steps` iterations) with CFG
5. `feature_utils.decode(latent)` → mel spectrogram
6. `feature_utils.vocode(spec)` → waveform (BigVGAN for 16k, direct for 44k)
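
Step 4 combines two ideas: fixed-step Euler integration of a velocity field, and classifier-free guidance (CFG) that mixes a conditional and an unconditional prediction. The sketch below shows both on plain Python lists. It assumes the standard CFG formulation `v = v_uncond + cfg * (v_cond - v_uncond)`; the toy velocity fields stand in for the generator network and are purely hypothetical.

```python
from typing import Callable, List

Vector = List[float]


def cfg_velocity(v_cond: Vector, v_uncond: Vector, cfg_strength: float) -> Vector:
    """Classifier-free guidance: extrapolate away from the unconditional prediction."""
    return [u + cfg_strength * (c - u) for c, u in zip(v_cond, v_uncond)]


def euler_sample(velocity: Callable[[float, Vector], Vector],
                 x0: Vector, num_steps: int) -> Vector:
    """Integrate dx/dt = velocity(t, x) from t=0 to t=1 with fixed Euler steps."""
    x = list(x0)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        v = velocity(t, x)                      # one network evaluation per step
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


def guided(t: float, x: Vector) -> Vector:
    """Toy field: conditional pull towards 1.0, unconditional pull towards 0.0."""
    v_cond = [1.0 - xi for xi in x]
    v_uncond = [0.0 - xi for xi in x]
    return cfg_velocity(v_cond, v_uncond, cfg_strength=4.5)


x1 = euler_sample(guided, x0=[0.0, 0.5], num_steps=25)
print(x1)
```

With `cfg_strength=1.0` the guided velocity equals the conditional one; larger values push the sample further in the direction the condition prefers, which is why the negative prompt enters through the unconditional branch.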

**Note on dual prompt:** The extractor prompt is baked into sync_features via
TextSynchformer at extraction time. The sampler prompt drives CLIP text conditioning
at generation time. They should match; a tooltip explains this.

---

## Data Flow

```
[VHS LoadVideo] ──► [SelvaFeatureExtractor]
                      │  prompt: "dog barking"
                      │  video_info: (fps auto)
                      ▼
               SELVA_FEATURES
               {clip_features [B,T,1024],
                sync_features [B,T,768],   ← text-conditioned
                duration: 8.2s}
                      │
[SelvaModelLoader] ──► [SelvaSampler]
  variant: medium_44k    │  prompt: "dog barking"
  precision: bf16        │  negative: "wind noise"
                         │  cfg_strength: 4.5, steps: 25
                         ▼
                    AUDIO (44.1kHz or 16kHz)
```

---

## Model Weights

Location: `models/selva/`

```
video_enc_sup_5.pth              ← TextSynch, shared across all variants
generator_small_16k_sup_5.pth
generator_small_44k_sup_5.pth
generator_medium_44k_sup_5.pth
generator_large_44k_sup_5.pth
ext/
  v1-16.pth                      ← VAE for 16k variants
  v1-44.pth                      ← VAE for 44k variants
  best_netG.pt                   ← BigVGAN vocoder (16k only)
```

`synchformer_state_dict.pth` is reused from `models/prismaudio/`; it is not duplicated.
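
The layout above implies a small per-variant checklist of files. The following sketch derives that checklist and reports what is missing from a given directory. The helper names (`required_files`, `missing_files`) are hypothetical, and the rule that the mode is read from the variant suffix is an assumption made for illustration.

```python
import tempfile
from pathlib import Path
from typing import List


def required_files(variant: str) -> List[str]:
    """Files needed under models/selva/ for one variant (paths relative to that dir)."""
    rate = "16" if variant.endswith("16k") else "44"
    files = [
        "video_enc_sup_5.pth",                 # shared across all variants
        f"generator_{variant}_sup_5.pth",
        f"ext/v1-{rate}.pth",                  # VAE matching the sample rate
    ]
    if rate == "16":
        files.append("ext/best_netG.pt")       # BigVGAN vocoder, 16k only
    return files


def missing_files(selva_dir: Path, variant: str) -> List[str]:
    """Return the required files that do not exist yet."""
    return [f for f in required_files(variant) if not (selva_dir / f).exists()]


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "video_enc_sup_5.pth").touch()     # pretend one file was downloaded
    print(missing_files(root, "small_16k"))
```

Listing every missing file at once lets an error message give complete download instructions instead of failing on one file per run.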

---

## selva_core vendoring

Copy from `jnwnlee/selva` (pinned to a specific commit for stability):
- `selva_core/model/`: MMAudio, TextSynch, transformer layers, embeddings, flow matching
- `selva_core/ext/`: autoencoder, BigVGAN, synchformer, rotary embeddings, mel converters
- `selva_core/utils/`: transforms, generate() helper

Rename all internal imports from `selva.*` → `selva_core.*`.

---

## What stays the same

- All PrismAudio nodes unchanged
- `models/prismaudio/` unchanged
- Synchformer checkpoint shared (not duplicated)
- Branch: new `feature/selva-integration` off master (LoRA work stays separate)
@@ -0,0 +1,738 @@
# SelVA Integration Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Add three new ComfyUI nodes (SelvaModelLoader, SelvaFeatureExtractor, SelvaSampler) that run SelVA's text-conditioned V2A pipeline inline: no subprocess, no JAX, pure PyTorch.

**Architecture:** Vendor SelVA source into `selva_core/`, implement three nodes that mirror the PrismAudio pattern. `SelvaFeatureExtractor` takes `SELVA_MODEL` (needs TextSynchformer + CLIP/T5 from FeaturesUtils). `SelvaSampler` runs flow matching ODE with CFG and negative prompts.

**Tech Stack:** PyTorch, open_clip (already in ComfyUI), transformers (already in ComfyUI), torchaudio, einops, torchvision

---

## Design reference

`docs/plans/2026-04-04-selva-integration-design.md`

**Key facts from SelVA source:**
- CLIP input: `[B, T, C, 384, 384]` float32 `[0,1]`; normalization is applied inside FeaturesUtils
- Sync input: `[B, T, C, 224, 224]` float32 `[-1,1]`; normalize with `mean=std=[0.5,0.5,0.5]` before passing
- CLIP frame rate: 8fps, Sync frame rate: 25fps
- CONFIG_16K: latent=250, clip=64, sync=192 at 8s
- CONFIG_44K: latent=345, clip=64, sync=192 at 8s
- Sync segments: 16-frame windows, 8-frame stride (overlapping, unlike PrismAudio's 8-frame non-overlapping)
- `net_generator.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)` must be called before each generation when duration ≠ 8s
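
The sequence lengths quoted above follow from the frame rates and window sizes. The sketch below reproduces the 8 s numbers so the arithmetic can be checked for other durations. `SeqLengths` is a hypothetical stand-in for SelVA's `SequenceConfig`; the latent downsample rate of 2 and the 8 tokens per 16-frame sync window are assumptions chosen because they yield the documented values.

```python
import math
from dataclasses import dataclass


@dataclass
class SeqLengths:
    """Hypothetical stand-in for SequenceConfig; default rates are assumptions."""
    duration: float
    sampling_rate: int
    spectrogram_frame_rate: int       # audio samples per spectrogram frame
    latent_downsample_rate: int = 2   # assumed temporal downsampling of the VAE
    clip_fps: int = 8
    sync_fps: int = 25
    sync_window: int = 16
    sync_stride: int = 8
    sync_downsample_rate: int = 2     # assumed: 16-frame window -> 8 tokens

    @property
    def latent_seq_len(self) -> int:
        frames = self.duration * self.sampling_rate / self.spectrogram_frame_rate
        return math.ceil(frames / self.latent_downsample_rate)

    @property
    def clip_seq_len(self) -> int:
        return int(self.duration * self.clip_fps)

    @property
    def sync_seq_len(self) -> int:
        num_frames = int(self.duration * self.sync_fps)
        # Overlapping windows: 16 frames wide, advancing 8 frames at a time
        num_segments = (num_frames - self.sync_window) // self.sync_stride + 1
        return num_segments * self.sync_window // self.sync_downsample_rate


cfg_16k = SeqLengths(duration=8.0, sampling_rate=16000, spectrogram_frame_rate=256)
cfg_44k = SeqLengths(duration=8.0, sampling_rate=44100, spectrogram_frame_rate=512)
print(cfg_16k.latent_seq_len, cfg_16k.clip_seq_len, cfg_16k.sync_seq_len)
print(cfg_44k.latent_seq_len, cfg_44k.clip_seq_len, cfg_44k.sync_seq_len)
```

For 8 s at 25 fps there are 200 sync frames, giving (200 - 16) // 8 + 1 = 24 windows and 24 × 8 = 192 tokens, which matches the `sync=192` entry for both configs.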

---

## Task 1: Create branch and vendor selva_core

**Files:**
- Create: `selva_core/` (full directory tree)

**Step 1: Create new branch off master (not off feature/lora-trainer)**

```bash
git checkout master
git checkout -b feature/selva-integration
```

**Step 2: Clone SelVA and copy source**

```bash
git clone https://github.com/jnwnlee/selva.git /tmp/selva_src
cp -r /tmp/selva_src/selva /media/p5/Comfyui-Prismaudio/selva_core
```

**Step 3: Rename all internal imports**

```bash
cd /media/p5/Comfyui-Prismaudio/selva_core
find . -name "*.py" -exec sed -i \
  's/from selva\./from selva_core./g;
   s/import selva\./import selva_core./g' {} \;
```
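
The same rewrite can be expressed in Python, which makes it easy to test on sample input before touching the tree. This is a sketch, not the procedure used in the plan: `rewrite_imports` is a hypothetical helper, and unlike the `sed` expressions above it also matches the bare forms `from selva import ...` and `import selva`.

```python
import re

# "from selva" or "import selva" at the start of a statement, followed by
# a dot, whitespace, or end of line (so "selva_core" and "selvage" are skipped)
_IMPORT_RE = re.compile(r"^(\s*(?:from|import)\s+)selva(?=[.\s]|$)", re.MULTILINE)


def rewrite_imports(source: str) -> str:
    """Point selva imports at the vendored selva_core package."""
    return _IMPORT_RE.sub(r"\1selva_core", source)


sample = (
    "from selva.model.flow_matching import FlowMatching\n"
    "import selva.utils.features_utils\n"
    "from selva import ext\n"
    "from selva_core.model import sequence_config\n"
)
print(rewrite_imports(sample))
```

Because the pattern refuses to match `selva_core`, running it twice leaves the source unchanged, so a partially rewritten tree can be processed again safely.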

**Step 4: Record the pinned commit**

```bash
cd /tmp/selva_src && git rev-parse HEAD
# Paste the hash into a comment at the top of selva_core/__init__.py
```

Edit `selva_core/__init__.py` to add at the top:
```python
# Vendored from https://github.com/jnwnlee/selva
# Pinned commit: <PASTE_HASH_HERE>
# Imports rewritten from selva.* → selva_core.*
```

**Step 5: Verify imports work**

```bash
cd /media/p5/Comfyui-Prismaudio
python -c "
from selva_core.model.networks_generator import MMAudio, get_my_mmaudio
from selva_core.model.networks_video_enc import TextSynch, get_my_textsynch
from selva_core.model.utils.features_utils import FeaturesUtils
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
print('selva_core imports OK')
print(f'CONFIG_16K: latent={CONFIG_16K.latent_seq_len} clip={CONFIG_16K.clip_seq_len} sync={CONFIG_16K.sync_seq_len}')
print(f'CONFIG_44K: latent={CONFIG_44K.latent_seq_len} clip={CONFIG_44K.clip_seq_len} sync={CONFIG_44K.sync_seq_len}')
"
```

Expected:
```
selva_core imports OK
CONFIG_16K: latent=250 clip=64 sync=192
CONFIG_44K: latent=345 clip=64 sync=192
```

**Step 6: Commit**

```bash
git add selva_core/
git commit -m "chore: vendor selva_core from jnwnlee/selva@<HASH>

Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
Imports rewritten from selva.* to selva_core.*. No training code included."
```

---

## Task 2: Implement SelvaModelLoader

**Files:**
- Create: `nodes/selva_model_loader.py`
- Modify: `nodes/__init__.py`

**Step 1: Create `nodes/selva_model_loader.py`**

```python
import os
import torch
import folder_paths

from .utils import PRISMAUDIO_CATEGORY, get_offload_device, determine_offload_strategy

# Variant → (generator filename, mode, has_bigvgan)
_VARIANTS = {
    "small_16k":  ("generator_small_16k_sup_5.pth",  "16k", True),
    "small_44k":  ("generator_small_44k_sup_5.pth",  "44k", False),
    "medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False),
    "large_44k":  ("generator_large_44k_sup_5.pth",  "44k", False),
}

# Mode → VAE checkpoint filename under models/selva/ext/
_VAE_FILES = {"16k": "v1-16.pth", "44k": "v1-44.pth"}

_SELVA_DIR = os.path.join(folder_paths.models_dir, "selva")


def _selva_path(*parts):
    return os.path.join(_SELVA_DIR, *parts)


def _require(path, hint):
    """Return path if it exists, otherwise raise with download instructions."""
    if not os.path.exists(path):
        raise RuntimeError(
            f"[SelVA] Missing: {path}\n{hint}"
        )
    return path


class SelvaModelLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Defaults match the design doc (medium_44k / bf16 / auto)
                "variant": (list(_VARIANTS.keys()), {"default": "medium_44k"}),
                "precision": (["bf16", "fp16", "fp32"], {"default": "bf16"}),
                "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"], {"default": "auto"}),
            }
        }

    RETURN_TYPES = ("SELVA_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_model(self, variant, precision, offload_strategy):
        from selva_core.model.networks_generator import get_my_mmaudio
        from selva_core.model.networks_video_enc import get_my_textsynch
        from selva_core.model.utils.features_utils import FeaturesUtils
        from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K

        gen_filename, mode, has_bigvgan = _VARIANTS[variant]

        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        strategy = determine_offload_strategy(offload_strategy)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Resolve weight paths
        video_enc_path = _require(
            _selva_path("video_enc_sup_5.pth"),
            "Download from https://huggingface.co/jnwnlee/selva and place in models/selva/"
        )
        gen_path = _require(
            _selva_path(gen_filename),
            f"Download {gen_filename} from https://huggingface.co/jnwnlee/selva and place in models/selva/"
        )
        # The VAE files are named v1-16.pth / v1-44.pth (no "k" suffix)
        vae_filename = _VAE_FILES[mode]
        vae_path = _require(
            _selva_path("ext", vae_filename),
            f"Download {vae_filename} from MMAudio/SelVA release and place in models/selva/ext/"
        )
        synch_path = _require(
            os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth"),
            "Synchformer checkpoint missing from models/prismaudio/; download from FunAudioLLM/PrismAudio"
        )
        bigvgan_path = None
        if has_bigvgan:
            bigvgan_path = _require(
                _selva_path("ext", "best_netG.pt"),
                "Download best_netG.pt (BigVGAN 16k vocoder) from MMAudio release and place in models/selva/ext/"
            )

        print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
        net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
        net_video_enc.load_weights(
            torch.load(video_enc_path, map_location="cpu", weights_only=True)
        )

        print(f"[SelVA] Loading MMAudio ({variant}) from {gen_path}", flush=True)
        seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
        net_generator = get_my_mmaudio(variant).to(device, dtype).eval()
        net_generator.load_weights(
            torch.load(gen_path, map_location="cpu", weights_only=True)
        )

        print("[SelVA] Loading FeaturesUtils (CLIP + T5 + Synchformer + VAE)...", flush=True)
        feature_utils = FeaturesUtils(
            tod_vae_ckpt=vae_path,
            synchformer_ckpt=synch_path,
            enable_conditions=True,
            mode=mode,
            bigvgan_vocoder_ckpt=bigvgan_path,
        ).to(device, dtype).eval()

        if strategy == "offload_to_cpu":
            net_generator.to(get_offload_device())
            net_video_enc.to(get_offload_device())
            feature_utils.to(get_offload_device())

        print(f"[SelVA] Model ready: variant={variant} dtype={dtype} strategy={strategy}", flush=True)

        return ({
            "generator": net_generator,
            "video_enc": net_video_enc,
            "feature_utils": feature_utils,
            "variant": variant,
            "mode": mode,
            "strategy": strategy,
            "dtype": dtype,
            "seq_cfg": seq_cfg,
        },)
```

**Step 2: Register in `nodes/__init__.py`**

In the `NODE_CLASS_MAPPINGS` dict, add:
```python
"SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
```

**Step 3: Verify node registers**

```bash
cd /media/p5/Comfyui-Prismaudio
python -c "
import sys; sys.path.insert(0, '.')
from nodes.selva_model_loader import SelvaModelLoader
print('inputs:', list(SelvaModelLoader.INPUT_TYPES()['required'].keys()))
print('outputs:', SelvaModelLoader.RETURN_TYPES)
"
```

Expected: `inputs: ['variant', 'precision', 'offload_strategy']`

**Step 4: Commit**

```bash
git add nodes/selva_model_loader.py nodes/__init__.py
git commit -m "feat: SelvaModelLoader node (loads TextSynch + MMAudio + FeaturesUtils)"
```

---

## Task 3: Implement SelvaFeatureExtractor

**Files:**
- Create: `nodes/selva_feature_extractor.py`
- Modify: `nodes/__init__.py`

**Step 1: Create `nodes/selva_feature_extractor.py`**

```python
import os
import hashlib
import tempfile

import torch
import torch.nn.functional as F
import numpy as np

from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache

# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
_CLIP_SIZE = 384
_SYNC_SIZE = 224
_CLIP_FPS = 8
_SYNC_FPS = 25

# Sync normalization: [-1, 1] (from selva/utils/eval_utils.py load_video)
_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)


def _sample_frames(video, source_fps, target_fps, duration):
    """Sample frames from [T,H,W,C] float32 [0,1] at target_fps."""
    T = video.shape[0]
    n_out = max(1, int(duration * target_fps))
    indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)]
    return video[indices]  # [N, H, W, C]


def _resize_frames(frames, size):
    """Resize [N,H,W,C] float32 [0,1] → [N,C,H,W] at target size."""
    x = frames.permute(0, 3, 1, 2)  # [N, C, H, W]
    x = F.interpolate(x, size=(size, size), mode="bicubic", align_corners=False)
    return x.clamp(0, 1)  # [N, C, H, W] float32


def _hash_inputs(video_tensor, prompt, fps, duration, variant):
    """Cache key from the first 1 MB of frame data plus every input that affects the features."""
    h = hashlib.sha256()
    # Slice before converting so only the first 1 MB (262144 float32 values) is copied
    head = video_tensor.flatten()[:1024 * 1024 // 4].contiguous()
    h.update(head.cpu().numpy().tobytes())
    h.update(str(tuple(video_tensor.shape)).encode())
    h.update(prompt.encode())
    h.update(str(fps).encode())
    h.update(f"{duration:.3f}".encode())  # a duration override changes the features
    h.update(variant.encode())
    return h.hexdigest()[:16]


class SelvaFeatureExtractor:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "video": ("IMAGE",),
                "prompt": ("STRING", {"default": "", "multiline": True,
                                      "tooltip": "Text prompt used by TextSynchformer to focus sync features on the relevant sound source. Should match the prompt used in SelvaSampler."}),
            },
            "optional": {
                "video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info to auto-set fps."}),
                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001}),
                "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
                                       "tooltip": "Override duration in seconds. 0 = infer from video length and fps."}),
                "cache_dir": ("STRING", {"default": "", "tooltip": "Directory for cached .npz features. Empty = temp dir."}),
            },
        }

    RETURN_TYPES = ("SELVA_FEATURES", "FLOAT")
    RETURN_NAMES = ("features", "fps")
    FUNCTION = "extract_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
                         duration=0.0, cache_dir=""):
        if video_info is not None:
            fps = video_info["loaded_fps"]

        T = video.shape[0]
        if duration <= 0:
            duration = T / fps
        duration = min(duration, T / fps)  # clamp to actual video length

        if not prompt.strip():
            print("[SelVA] Warning: empty prompt; TextSynchformer sync features will be unfocused.", flush=True)

        # Cache
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
        os.makedirs(cache_dir, exist_ok=True)
        cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"])
        cached_path = os.path.join(cache_dir, f"{cache_key}.npz")

        if os.path.exists(cached_path):
            print(f"[SelVA] Using cached features: {cached_path}", flush=True)
            return (_load_cached(cached_path), float(fps))

        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        feature_utils = model["feature_utils"]
        net_video_enc = model["video_enc"]

        # Move feature models to device
        if strategy == "offload_to_cpu":
            feature_utils.to(device)
            net_video_enc.to(device)
            soft_empty_cache()

        print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)

        with torch.no_grad():
            # --- CLIP frames: 384×384, [0,1], 8fps ---
            clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration)  # [N, H, W, C]
            clip_frames = _resize_frames(clip_frames, _CLIP_SIZE)          # [N, C, 384, 384]
            clip_input = clip_frames.unsqueeze(0).to(device, dtype)        # [1, N, C, 384, 384]
            print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps", flush=True)

            clip_features = feature_utils.encode_video_with_clip(clip_input)  # [1, N, 1024]

            # --- Sync frames: 224×224, [-1,1], 25fps ---
            sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration)
            if sync_frames.shape[0] < 16:
                # Pad by repeating the last frame: segmentation needs at least one 16-frame window
                pad = 16 - sync_frames.shape[0]
                sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
            sync_frames = _resize_frames(sync_frames, _SYNC_SIZE)  # [N, C, 224, 224]
            # Normalize to [-1, 1]
            mean = _SYNC_MEAN.to(sync_frames.device)
            std = _SYNC_STD.to(sync_frames.device)
            sync_frames = (sync_frames - mean) / std
            sync_input = sync_frames.unsqueeze(0).to(device, dtype)  # [1, N, C, 224, 224]
            print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps", flush=True)

            # Encode T5 text + prepend supplementary tokens → text-conditioned sync features
            text_f_t5, text_mask = feature_utils.encode_text_t5([prompt])  # [1, L, 768], [1, L]
            text_f_t5, text_mask = net_video_enc.prepend_sup_text_tokens(text_f_t5, text_mask)
            sync_features = net_video_enc.encode_video_with_sync(
                sync_input, text_f=text_f_t5, text_mask=text_mask
            )  # [1, T_sync, 768]

        print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
        print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)

        # Offload back if needed
        if strategy == "offload_to_cpu":
            feature_utils.to(get_offload_device())
            net_video_enc.to(get_offload_device())
            soft_empty_cache()

        # Save cache (stored as float32 because NumPy has no bfloat16)
        np.savez(
            cached_path,
            clip_features=clip_features.cpu().float().numpy(),
            sync_features=sync_features.cpu().float().numpy(),
            duration=duration,
        )
        print(f"[SelVA] Features cached: {cached_path}", flush=True)

        features = {
            "clip_features": clip_features.cpu(),
            "sync_features": sync_features.cpu(),
            "duration": duration,
        }
        return (features, float(fps))


def _load_cached(path):
    data = np.load(path, allow_pickle=False)
    return {
        "clip_features": torch.from_numpy(data["clip_features"]),
        "sync_features": torch.from_numpy(data["sync_features"]),
        "duration": float(data["duration"]),
    }
```
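
The index arithmetic in `_sample_frames` is easier to check without tensors. The sketch below isolates it as a pure function; `sample_indices` is a hypothetical name used only here, and the 2-second, 30 fps clip is an invented example.

```python
from typing import List


def sample_indices(num_frames: int, source_fps: float,
                   target_fps: float, duration: float) -> List[int]:
    """Source-frame index for each output frame when resampling to target_fps."""
    n_out = max(1, int(duration * target_fps))
    # Output frame i sits at time i / target_fps; pick the source frame shown
    # at that time and clamp to the last available frame.
    return [min(int(i / target_fps * source_fps), num_frames - 1) for i in range(n_out)]


clip_idx = sample_indices(num_frames=60, source_fps=30.0, target_fps=8, duration=2.0)
sync_idx = sample_indices(num_frames=60, source_fps=30.0, target_fps=25, duration=2.0)
print(len(clip_idx), len(sync_idx))
```

A 2 s clip gives 16 CLIP frames and 50 sync frames, so both branches see the same time span at their own rate; downsampling skips source frames and upsampling would repeat them.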

**Step 2: Register in `nodes/__init__.py`**

```python
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
```

**Step 3: Verify node registers**

```bash
python -c "
import sys; sys.path.insert(0, '.')
from nodes.selva_feature_extractor import SelvaFeatureExtractor
inputs = SelvaFeatureExtractor.INPUT_TYPES()
print('required:', list(inputs['required'].keys()))
print('optional:', list(inputs['optional'].keys()))
print('outputs:', SelvaFeatureExtractor.RETURN_TYPES)
"
```

Expected: `required: ['model', 'video', 'prompt']`

**Step 4: Commit**

```bash
git add nodes/selva_feature_extractor.py nodes/__init__.py
git commit -m "feat: SelvaFeatureExtractor (inline CLIP + TextSynchformer feature extraction)"
```

---
|
||||||
|
|
||||||
|
## Task 4: Implement SelvaSampler
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `nodes/selva_sampler.py`
|
||||||
|
- Modify: `nodes/__init__.py`
|
||||||
|
|
||||||
|
**Step 1: Create `nodes/selva_sampler.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
PRISMAUDIO_CATEGORY,
|
||||||
|
get_device, get_offload_device, soft_empty_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seq_cfg(duration, mode):
|
||||||
|
"""Compute sequence lengths for a given duration and mode."""
|
||||||
|
from selva_core.model.sequence_config import SequenceConfig
|
||||||
|
if mode == "16k":
|
||||||
|
return SequenceConfig(duration=duration, sampling_rate=16000, spectrogram_frame_rate=256)
|
||||||
|
else:
|
||||||
|
return SequenceConfig(duration=duration, sampling_rate=44100, spectrogram_frame_rate=512)
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaSampler:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("SELVA_MODEL",),
|
||||||
|
"features": ("SELVA_FEATURES",),
|
||||||
|
"prompt": ("STRING", {"default": "", "multiline": True,
|
||||||
|
"tooltip": "Should match the prompt used in SelvaFeatureExtractor."}),
|
||||||
|
"negative_prompt": ("STRING", {"default": "", "multiline": True,
|
||||||
|
"tooltip": "Sounds to steer away from, e.g. 'wind noise, background music'."}),
|
||||||
|
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
|
||||||
|
"tooltip": "Audio duration in seconds. 0 = use duration from features."}),
|
||||||
|
"steps": ("INT", {"default": 25, "min": 1, "max": 200}),
|
||||||
|
"cfg_strength": ("FLOAT", {"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("AUDIO",)
|
||||||
|
RETURN_NAMES = ("audio",)
|
||||||
|
FUNCTION = "generate"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed):
|
||||||
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
|
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
strategy = model["strategy"]
|
||||||
|
net_generator = model["generator"]
|
||||||
|
feature_utils = model["feature_utils"]
|
||||||
|
mode = model["mode"]
|
||||||
|
|
||||||
|
# Resolve duration
|
||||||
|
if duration <= 0:
|
||||||
|
if "duration" not in features:
|
||||||
|
raise ValueError("[SelVA] duration=0 but features contain no duration field.")
|
||||||
|
duration = features["duration"]
|
||||||
|
print(f"[SelVA] Using video duration from features: {duration:.2f}s", flush=True)
|
||||||
|
|
||||||
|
seq_cfg = _make_seq_cfg(duration, mode)
|
||||||
|
sample_rate = seq_cfg.sampling_rate
|
||||||
|
|
||||||
|
# Move models to device
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
net_generator.to(device)
|
||||||
|
feature_utils.to(device)
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024]
|
||||||
|
sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768]
|
||||||
|
|
||||||
|
print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)
|
||||||
|
print(f"[SelVA] seq_cfg: latent={seq_cfg.latent_seq_len} clip={seq_cfg.clip_seq_len} sync={seq_cfg.sync_seq_len}", flush=True)
|
||||||
|
|
||||||
|
# Update model sequence lengths for this duration
|
||||||
|
net_generator.update_seq_lengths(
|
||||||
|
latent_seq_len=seq_cfg.latent_seq_len,
|
||||||
|
clip_seq_len=seq_cfg.clip_seq_len,
|
||||||
|
sync_seq_len=seq_cfg.sync_seq_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Encode text
|
||||||
|
text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D]
|
||||||
|
|
||||||
|
# Build empty (negative) conditions for CFG
|
||||||
|
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
||||||
|
if negative_prompt.strip() else None
|
||||||
|
|
||||||
|
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
||||||
|
empty_conditions = net_generator.get_empty_conditions(
|
||||||
|
bs=1, negative_text_features=neg_text_clip
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sample initial noise
|
||||||
|
rng = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
x0 = torch.randn(
|
||||||
|
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
||||||
|
device=device, dtype=dtype, generator=rng
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flow matching ODE (Euler)
|
||||||
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
||||||
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
|
_step_count = [0]
|
||||||
|
|
||||||
|
# Wrap ODE to update progress bar
|
||||||
|
def ode_wrapper_tracked(t, x):
|
||||||
|
_step_count[0] += 1
|
||||||
|
pbar.update(1)
|
||||||
|
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
||||||
|
|
||||||
|
x1 = fm.to_data(ode_wrapper_tracked, x0)
|
||||||
|
|
||||||
|
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
||||||
|
|
||||||
|
# Decode: latent → mel → audio
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
feature_utils.to(device)
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
x1_unnorm = net_generator.unnormalize(x1)
|
||||||
|
spec = feature_utils.decode(x1_unnorm)
|
||||||
|
audio = feature_utils.vocode(spec) # [1, samples] or [1, 1, samples]
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
net_generator.to(get_offload_device())
|
||||||
|
feature_utils.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
# Reshape to mono [1, 1, samples], then peak-normalise to [-1, 1]
|
||||||
|
audio = audio.float()
|
||||||
|
if audio.dim() == 2:
|
||||||
|
audio = audio.unsqueeze(1) # [1, 1, samples]
|
||||||
|
elif audio.dim() == 3 and audio.shape[1] != 1:
|
||||||
|
audio = audio.mean(dim=1, keepdim=True) # stereo → mono
|
||||||
|
|
||||||
|
peak = audio.abs().max().clamp(min=1e-8)
|
||||||
|
audio = (audio / peak).clamp(-1, 1)
|
||||||
|
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
|
||||||
|
|
||||||
|
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
|
||||||
|
```
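
The final normalisation step in the sampler can be checked in isolation. The sketch below mirrors the same arithmetic on plain Python lists rather than tensors: it averages the channels to mono, divides by the peak, and floors the peak at 1e-8 so silence does not divide by zero. `to_mono` and `peak_normalize` are hypothetical helper names used only for illustration; they are not part of the plugin.

```python
from typing import List


def to_mono(channels: List[List[float]]) -> List[float]:
    """Average equally long channels sample by sample (stereo -> mono)."""
    n = len(channels)
    return [sum(frame) / n for frame in zip(*channels)]


def peak_normalize(samples: List[float], eps: float = 1e-8) -> List[float]:
    """Scale so the largest magnitude becomes 1.0; eps guards against silence."""
    peak = max(max((abs(s) for s in samples), default=0.0), eps)
    # Clamp afterwards, as the sampler does, to absorb rounding overshoot.
    return [max(-1.0, min(1.0, s / peak)) for s in samples]


mono = to_mono([[0.5, -2.0, 1.0], [0.5, -2.0, 0.0]])
normalized = peak_normalize(mono)
silent = peak_normalize([0.0, 0.0])
```

Because each clip is scaled to its own peak, quiet and loud generations leave the node at the same level, and an all-zero clip stays all zeros.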
|
||||||
|
|
||||||
|
**Step 2: Register in `nodes/__init__.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Verify node registers**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
import sys; sys.path.insert(0, '.')
|
||||||
|
from nodes.selva_sampler import SelvaSampler
|
||||||
|
inputs = SelvaSampler.INPUT_TYPES()
|
||||||
|
print('inputs:', list(inputs['required'].keys()))
|
||||||
|
print('outputs:', SelvaSampler.RETURN_TYPES)
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: `inputs: ['model', 'features', 'prompt', 'negative_prompt', 'duration', 'steps', 'cfg_strength', 'seed']`
|
||||||
|
|
||||||
|
**Step 4: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add nodes/selva_sampler.py nodes/__init__.py
|
||||||
|
git commit -m "feat: SelvaSampler — flow matching ODE with CFG + negative prompts"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 5: Create example workflow and push
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `workflows/selva_video_to_audio.json`
|
||||||
|
|
||||||
|
**Step 1: Create workflow JSON**
|
||||||
|
|
||||||
|
Create `workflows/selva_video_to_audio.json` with this node graph:
|
||||||
|
- LoadVideo (VHS) → IMAGE + VHS_VIDEOINFO
- SelvaModelLoader → SELVA_MODEL
- SelvaFeatureExtractor (takes IMAGE + VHS_VIDEOINFO + SELVA_MODEL, prompt) → SELVA_FEATURES
- SelvaSampler (takes SELVA_MODEL + SELVA_FEATURES, prompt, negative_prompt) → AUDIO
- PreviewAudio (takes AUDIO)
|
||||||
|
|
||||||
|
Set defaults: variant=medium_44k, precision=bf16, steps=25, cfg_strength=4.5, duration=0.
|
||||||
|
|
||||||
|
**Step 2: Push branch**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git push -u origin feature/selva-integration
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 6: Smoke test
|
||||||
|
|
||||||
|
**Step 1: Check all three nodes are importable from ComfyUI's perspective**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /media/p5/Comfyui-Prismaudio
|
||||||
|
python -c "
|
||||||
|
import sys; sys.path.insert(0, '.')
|
||||||
|
import nodes
|
||||||
|
m = nodes.NODE_CLASS_MAPPINGS
|
||||||
|
print('SelVA nodes:', [k for k in m if 'Selva' in k])
|
||||||
|
assert 'SelvaModelLoader' in m
|
||||||
|
assert 'SelvaFeatureExtractor' in m
|
||||||
|
assert 'SelvaSampler' in m
|
||||||
|
print('All SelVA nodes registered OK')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Verify no import errors in full node load**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
import sys; sys.path.insert(0, '.')
|
||||||
|
from nodes.selva_model_loader import SelvaModelLoader
|
||||||
|
from nodes.selva_feature_extractor import SelvaFeatureExtractor
|
||||||
|
from nodes.selva_sampler import SelvaSampler
|
||||||
|
print('All imports clean')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Final commit with any fixes**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add -A
|
||||||
|
git commit -m "fix: selva integration smoke test fixes (if any)"
|
||||||
|
git push
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- `FeaturesUtils.train()` is overridden to always call `super().train(False)`, so SelVA models are always in eval mode.
- `net_generator.update_seq_lengths` recalculates the rotary position embeddings; call it before every generation when the duration may vary.
- ProgressBar tracking: `FlowMatching.to_data` calls `fn(t, x)` once per Euler step, so wrapping `ode_wrapper` with a counter gives accurate progress.
- `feature_utils.vocode` returns audio at 16 kHz for `small_16k` (uses BigVGAN) and at 44.1 kHz for the 44k variants (uses the VAE mel decoder directly).
- If `encode_text_t5` or `encode_text_clip` fails with missing-model errors on first run, Hugging Face is most likely still downloading `flan-t5-base` and `apple/DFN5B-CLIP-ViT-H-14-384`; this is expected and takes a few minutes, once.
|
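
The progress-tracking note can be illustrated without the real sampler. The sketch below uses a stand-in fixed-step Euler integrator. `euler_to_data` is a hypothetical function, not SelVA's `FlowMatching.to_data`, whose internals are not shown in this plan. It calls `fn(t, x)` once per step, and `with_progress` (also hypothetical) counts those calls in the same way `ode_wrapper_tracked` does.

```python
from typing import Callable, List

Velocity = Callable[[float, List[float]], List[float]]


def euler_to_data(fn: Velocity, x0: List[float], num_steps: int) -> List[float]:
    """Integrate dx/dt = fn(t, x) from t=0 to t=1 with fixed-step Euler."""
    x = list(x0)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        v = fn(i * dt, x)  # exactly one velocity evaluation per step
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


def with_progress(fn: Velocity, on_step: Callable[[], None]) -> Velocity:
    """Wrap a velocity function so each evaluation reports one unit of progress."""
    def tracked(t: float, x: List[float]) -> List[float]:
        on_step()
        return fn(t, x)
    return tracked


def constant_velocity(t: float, x: List[float]) -> List[float]:
    """Toy velocity field; in the sampler the generator predicts this."""
    return [2.0, -1.0]


calls: List[int] = []
x1 = euler_to_data(with_progress(constant_velocity, lambda: calls.append(1)), [0.0, 0.0], 4)
```

The call count equals the step count only because Euler evaluates the velocity once per step. A higher-order solver would typically evaluate it several times per step, so the progress bar total would need adjusting.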
||||||
@@ -7,6 +7,8 @@ _NODES = {
|
|||||||
"PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
|
"PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
|
||||||
"PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
|
"PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
|
||||||
"PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
|
"PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
|
||||||
|
"PrismAudioLoRATrainer": (".lora_trainer", "PrismAudioLoRATrainer", "PrismAudio LoRA Trainer"),
|
||||||
|
"PrismAudioLoRALoader": (".lora_loader", "PrismAudioLoRALoader", "PrismAudio LoRA Loader"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||||
|
|||||||
@@ -13,13 +13,29 @@ _PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__))
|
|||||||
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
|
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
|
||||||
_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")
|
_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")
|
||||||
|
|
||||||
|
def _jax_package():
|
||||||
|
"""Return the correct jax extra for the current CUDA version."""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
cuda_ver = torch.version.cuda or ""
|
||||||
|
major = int(cuda_ver.split(".")[0]) if cuda_ver else 0
|
||||||
|
if major >= 13:
|
||||||
|
return "jax[cuda13]"
|
||||||
|
elif major >= 12:
|
||||||
|
return "jax[cuda12]"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return "jax" # CPU fallback
|
||||||
|
|
||||||
|
|
||||||
_EXTRACT_PACKAGES = [
|
_EXTRACT_PACKAGES = [
|
||||||
"torch", "torchaudio", "torchvision",
|
"torch", "torchaudio", "torchvision",
|
||||||
# TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
|
# TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
|
||||||
"tensorflow-cpu>=2.16.0",
|
"tensorflow-cpu>=2.16.0",
|
||||||
# jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed)
|
# jax CUDA extra is resolved at install time based on detected CUDA version
|
||||||
"jax[cuda13]", "flax",
|
_jax_package(), "flax",
|
||||||
"transformers", "decord", "einops", "numpy", "mediapy",
|
"transformers", "decord", "einops", "numpy",
|
||||||
"git+https://github.com/google-deepmind/videoprism.git",
|
"git+https://github.com/google-deepmind/videoprism.git",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -70,11 +86,12 @@ def _ensure_extract_env():
|
|||||||
return _MANAGED_PYTHON
|
return _MANAGED_PYTHON
|
||||||
|
|
||||||
|
|
||||||
def _hash_inputs(video_tensor, cot_text):
|
def _hash_inputs(video_tensor, cot_text, fps):
|
||||||
"""Create a hash of the inputs for caching."""
|
"""Create a hash of the inputs for caching."""
|
||||||
h = hashlib.sha256()
|
h = hashlib.sha256()
|
||||||
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed
|
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed
|
||||||
h.update(cot_text.encode())
|
h.update(cot_text.encode())
|
||||||
|
h.update(str(fps).encode()) # fps affects frame sampling — must be part of the key
|
||||||
return h.hexdigest()[:16]
|
return h.hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
@@ -115,6 +132,10 @@ class PrismAudioFeatureExtractor:
|
|||||||
if video_info is not None:
|
if video_info is not None:
|
||||||
fps = video_info["loaded_fps"]
|
fps = video_info["loaded_fps"]
|
||||||
|
|
||||||
|
if not caption_cot.strip():
|
||||||
|
print("[PrismAudio] Warning: caption_cot is empty — text features will be degenerate. "
|
||||||
|
"Provide a descriptive chain-of-thought caption for best results.", flush=True)
|
||||||
|
|
||||||
# Resolve python binary
|
# Resolve python binary
|
||||||
if python_env == "comfyui_env":
|
if python_env == "comfyui_env":
|
||||||
print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
|
print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
|
||||||
@@ -129,7 +150,7 @@ class PrismAudioFeatureExtractor:
|
|||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
# Check cache
|
# Check cache
|
||||||
cache_hash = _hash_inputs(video, caption_cot)
|
cache_hash = _hash_inputs(video, caption_cot, fps)
|
||||||
cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
|
cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
|
||||||
if os.path.exists(cached_path):
|
if os.path.exists(cached_path):
|
||||||
print(f"[PrismAudio] Using cached features: {cached_path}")
|
print(f"[PrismAudio] Using cached features: {cached_path}")
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .utils import PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_lora_weights(dit: nn.Module, lora_state: dict, rank: int, alpha: float, strength: float):
|
||||||
|
"""Add LoRA delta weights directly into the base model's nn.Linear tensors.
|
||||||
|
|
||||||
|
delta_W = (lora_B @ lora_A) * (alpha / rank) * strength
|
||||||
|
applied as: linear.weight += delta_W
|
||||||
|
|
||||||
|
This is equivalent to LoRALinear at inference but requires no wrapper,
|
||||||
|
no extra memory, and no change to the model's forward call graph.
|
||||||
|
"""
|
||||||
|
scale = (alpha / rank) * strength
|
||||||
|
|
||||||
|
# Group saved keys by module path
|
||||||
|
a_map = {
|
||||||
|
k.replace(".lora_A.weight", ""): v
|
||||||
|
for k, v in lora_state.items() if k.endswith("lora_A.weight")
|
||||||
|
}
|
||||||
|
b_map = {
|
||||||
|
k.replace(".lora_B.weight", ""): v
|
||||||
|
for k, v in lora_state.items() if k.endswith("lora_B.weight")
|
||||||
|
}
|
||||||
|
|
||||||
|
merged = 0
|
||||||
|
for path, lora_A in a_map.items():
|
||||||
|
if path not in b_map:
|
||||||
|
print(f"[PrismAudio] LoRA merge: missing lora_B for {path}, skipping", flush=True)
|
||||||
|
continue
|
||||||
|
lora_B = b_map[path] # [out_features, rank]
|
||||||
|
# delta_W: [out_features, in_features]
|
||||||
|
delta_W = (lora_B.float() @ lora_A.float()) * scale
|
||||||
|
|
||||||
|
# Navigate to the parent module using PyTorch's get_submodule
|
||||||
|
*parent_parts, child_name = path.split(".")
|
||||||
|
try:
|
||||||
|
parent = dit.get_submodule(".".join(parent_parts)) if parent_parts else dit
|
||||||
|
except AttributeError as e:
|
||||||
|
print(f"[PrismAudio] LoRA merge: could not find module '{path}': {e}", flush=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
linear = getattr(parent, child_name, None)
|
||||||
|
if not isinstance(linear, nn.Linear):
|
||||||
|
print(f"[PrismAudio] LoRA merge: expected nn.Linear at '{path}', got {type(linear)}", flush=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
linear.weight.data.add_(delta_W.to(linear.weight.dtype))
|
||||||
|
merged += 1
|
||||||
|
|
||||||
|
print(f"[PrismAudio] LoRA merged {merged} layer(s) (strength={strength:.3f})", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PrismAudioLoRALoader:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("PRISMAUDIO_MODEL",),
|
||||||
|
"lora_path": ("STRING", {"default": "", "tooltip": "Path to .safetensors LoRA file produced by PrismAudio LoRA Trainer"}),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05, "tooltip": "LoRA influence scale. 1.0 = full strength, 0.0 = base model only"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("PRISMAUDIO_MODEL",)
|
||||||
|
RETURN_NAMES = ("model",)
|
||||||
|
FUNCTION = "load_lora"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def load_lora(self, model, lora_path, strength):
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
if not os.path.exists(lora_path):
|
||||||
|
raise FileNotFoundError(f"[PrismAudio] LoRA file not found: {lora_path}")
|
||||||
|
|
||||||
|
config_path = lora_path.replace(".safetensors", "_config.json")
|
||||||
|
if not os.path.exists(config_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"[PrismAudio] LoRA config not found: {config_path}\n"
|
||||||
|
"Expected a _config.json alongside the .safetensors file."
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
rank = config["rank"]
|
||||||
|
alpha = config["alpha"]
|
||||||
|
|
||||||
|
lora_state = load_file(lora_path)
|
||||||
|
|
||||||
|
# Merge LoRA weights in-place into the DiT's base linear layers.
|
||||||
|
# ComfyUI re-executes the upstream ModelLoader on the next queue run
|
||||||
|
# when inputs change, providing a fresh base model as needed.
|
||||||
|
dit = model["model"].model # DiTWrapper
|
||||||
|
|
||||||
|
if strength == 0.0:
|
||||||
|
print("[PrismAudio] LoRA strength=0.0 — skipping merge, base model unchanged.", flush=True)
|
||||||
|
return (model,)
|
||||||
|
|
||||||
|
_merge_lora_weights(dit, lora_state, rank, alpha, strength)
|
||||||
|
|
||||||
|
return (model,)
|
||||||
@@ -0,0 +1,284 @@
|
|||||||
|
import os
|
||||||
|
import math
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
PRISMAUDIO_CATEGORY, SAMPLE_RATE,
|
||||||
|
get_device, get_offload_device, soft_empty_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LoRA primitives
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class LoRALinear(nn.Module):
|
||||||
|
"""Low-rank adapter wrapping a frozen nn.Linear."""
|
||||||
|
|
||||||
|
def __init__(self, linear: nn.Linear, rank: int, alpha: float):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = linear
|
||||||
|
self.scale = alpha / rank
|
||||||
|
in_f, out_f = linear.in_features, linear.out_features
|
||||||
|
self.lora_A = nn.Linear(in_f, rank, bias=False)
|
||||||
|
self.lora_B = nn.Linear(rank, out_f, bias=False)
|
||||||
|
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
||||||
|
nn.init.zeros_(self.lora_B.weight)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear(x) + self.lora_B(self.lora_A(x)) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
_TARGET_MODULE_PRESETS = {
|
||||||
|
"attn_only": {"to_q", "to_kv", "to_qkv", "to_out"},
|
||||||
|
"attn_ffn": {"to_q", "to_kv", "to_qkv", "to_out", "proj"},
|
||||||
|
"full": {"to_q", "to_kv", "to_qkv", "to_out", "proj", "project_in", "project_out"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_lora(module: nn.Module, target_attrs: set, rank: int, alpha: float):
|
||||||
|
"""Recursively replace matching nn.Linear layers with LoRALinear."""
|
||||||
|
for name, child in list(module.named_children()):
|
||||||
|
if isinstance(child, nn.Linear) and name in target_attrs:
|
||||||
|
setattr(module, name, LoRALinear(child, rank, alpha))
|
||||||
|
else:
|
||||||
|
_apply_lora(child, target_attrs, rank, alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def _unapply_lora(module: nn.Module):
|
||||||
|
"""Replace LoRALinear back with the original frozen Linear (no weight merge)."""
|
||||||
|
for name, child in list(module.named_children()):
|
||||||
|
if isinstance(child, LoRALinear):
|
||||||
|
child.linear.weight.requires_grad_(False)
|
||||||
|
setattr(module, name, child.linear)
|
||||||
|
else:
|
||||||
|
_unapply_lora(child)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_lora_state_dict(module: nn.Module) -> dict:
|
||||||
|
"""Return only LoRA parameter tensors from a module's state dict."""
|
||||||
|
return {k: v for k, v in module.state_dict().items()
|
||||||
|
if "lora_A" in k or "lora_B" in k}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dataset helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_AUDIO_EXTS = (".wav", ".flac", ".mp3")
|
||||||
|
|
||||||
|
|
||||||
|
def _scan_dataset(dataset_dir: str):
|
||||||
|
"""Return list of (npz_path, audio_path) pairs matched by stem."""
|
||||||
|
pairs = []
|
||||||
|
for fname in os.listdir(dataset_dir):
|
||||||
|
if not fname.endswith(".npz"):
|
||||||
|
continue
|
||||||
|
stem = os.path.join(dataset_dir, fname[:-4])
|
||||||
|
for ext in _AUDIO_EXTS:
|
||||||
|
audio_path = stem + ext
|
||||||
|
if os.path.exists(audio_path):
|
||||||
|
pairs.append((stem + ".npz", audio_path))
|
||||||
|
break
|
||||||
|
return sorted(pairs)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_audio(audio_path: str, device: torch.device) -> torch.Tensor:
|
||||||
|
"""Load audio to [1, 2, samples] float32 tensor at SAMPLE_RATE."""
|
||||||
|
import torchaudio
|
||||||
|
waveform, sr = torchaudio.load(audio_path)
|
||||||
|
if sr != SAMPLE_RATE:
|
||||||
|
waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
|
||||||
|
if waveform.shape[0] == 1:
|
||||||
|
waveform = waveform.expand(2, -1)
|
||||||
|
elif waveform.shape[0] > 2:
|
||||||
|
waveform = waveform[:2]
|
||||||
|
return waveform.unsqueeze(0).to(device) # [1, 2, samples]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_metadata(npz_path: str, device: torch.device, dtype: torch.dtype) -> dict:
|
||||||
|
"""Load .npz features into a conditioner metadata dict."""
|
||||||
|
import numpy as np
|
||||||
|
data = np.load(npz_path, allow_pickle=True)
|
||||||
|
video_feat = torch.from_numpy(data["video_features"]).float().to(device, dtype=dtype)
|
||||||
|
text_feat = torch.from_numpy(data["text_features"]).float().to(device, dtype=dtype)
|
||||||
|
sync_feat = torch.from_numpy(data["sync_features"]).float().to(device, dtype=dtype)
|
||||||
|
has_video = bool(video_feat.abs().sum() > 0)
|
||||||
|
return {
|
||||||
|
"video_features": video_feat,
|
||||||
|
"text_features": text_feat,
|
||||||
|
"sync_features": sync_feat,
|
||||||
|
"video_exist": torch.tensor(has_video),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Trainer node
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class PrismAudioLoRATrainer:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("PRISMAUDIO_MODEL",),
|
||||||
|
"dataset_dir": ("STRING", {"default": "", "tooltip": "Directory containing paired .npz feature files and .wav/.flac/.mp3 audio files (matched by filename stem)"}),
|
||||||
|
"output_path": ("STRING", {"default": "", "tooltip": "Save path for .safetensors weights. Empty = models/prismaudio/lora/"}),
|
||||||
|
"lora_rank": ("INT", {"default": 64, "min": 1, "max": 512}),
|
||||||
|
"lora_alpha": ("FLOAT", {"default": 64.0, "min": 1.0, "max": 1024.0}),
|
||||||
|
"target_modules": (["attn_ffn", "attn_only", "full"], {"tooltip": "attn_only: Q/K/V/out only. attn_ffn: + FFN input (recommended). full: + transformer I/O projections"}),
|
||||||
|
"learning_rate": ("FLOAT", {"default": 1e-4, "min": 1e-7, "max": 1e-2, "step": 1e-6}),
|
||||||
|
"train_steps": ("INT", {"default": 1000, "min": 1, "max": 100000}),
|
||||||
|
"cfg_dropout_prob": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 0.5, "step": 0.01, "tooltip": "Probability of dropping conditioning per step — preserves CFG ability at inference"}),
|
||||||
|
"save_every": ("INT", {"default": 500, "min": 1, "max": 100000, "tooltip": "Save a checkpoint every N steps (in addition to final save)"}),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
RETURN_NAMES = ("lora_path",)
|
||||||
|
FUNCTION = "train"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def train(self, model, dataset_dir, output_path, lora_rank, lora_alpha,
|
||||||
|
target_modules, learning_rate, train_steps, cfg_dropout_prob, save_every, seed):
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
diffusion = model["model"]
|
||||||
|
strategy = model["strategy"]
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
# Scan dataset
|
||||||
|
pairs = _scan_dataset(dataset_dir)
|
||||||
|
if not pairs:
|
||||||
|
raise RuntimeError(f"[PrismAudio] No (.npz + audio) pairs found in: {dataset_dir}")
|
||||||
|
print(f"[PrismAudio] LoRA training — {len(pairs)} sample(s), {train_steps} steps", flush=True)
|
||||||
|
|
||||||
|
# Resolve output path
|
||||||
|
if not output_path:
|
||||||
|
import folder_paths
|
||||||
|
out_dir = os.path.join(folder_paths.models_dir, "prismaudio", "lora")
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
output_path = os.path.join(out_dir, f"prismaudio_lora_r{lora_rank}.safetensors")
|
||||||
|
|
||||||
|
# Move model to device
|
||||||
|
diffusion.model.to(device)
|
||||||
|
diffusion.conditioner.to(device)
|
||||||
|
diffusion.pretransform.to(device)
|
||||||
|
|
||||||
|
# Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B)
|
||||||
|
dit = diffusion.model # DiTWrapper
|
||||||
|
for p in dit.parameters():
|
||||||
|
p.requires_grad_(False)
|
||||||
|
|
||||||
|
target_attrs = _TARGET_MODULE_PRESETS[target_modules]
|
||||||
|
_apply_lora(dit, target_attrs, lora_rank, lora_alpha)
|
||||||
|
|
||||||
|
# Cast LoRA params to model dtype and move to device
|
||||||
|
for m in dit.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_A.to(device=device, dtype=dtype)
|
||||||
|
m.lora_B.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
trainable = [p for p in dit.parameters() if p.requires_grad]
|
||||||
|
n_params = sum(p.numel() for p in trainable)
|
||||||
|
print(f"[PrismAudio] LoRA trainable params: {n_params:,} ({n_params/1e6:.2f}M)", flush=True)
|
||||||
|
|
||||||
|
diffusion.conditioner.eval()
|
||||||
|
diffusion.pretransform.eval()
|
||||||
|
dit.train()
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(trainable, lr=learning_rate)
|
||||||
|
|
||||||
|
# GradScaler for fp16 to prevent underflow
|
||||||
|
use_scaler = (dtype == torch.float16)
|
||||||
|
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
|
||||||
|
|
||||||
|
pbar = comfy.utils.ProgressBar(train_steps)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for step in range(1, train_steps + 1):
|
||||||
|
npz_path, audio_path = random.choice(pairs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Encode audio to latent space
|
||||||
|
audio = _load_audio(audio_path, device)
|
||||||
|
x0 = diffusion.pretransform.encode(audio.float()).to(dtype) # [1, 64, L]
|
||||||
|
|
||||||
|
# Build conditioning from features
|
||||||
|
metadata = (_load_metadata(npz_path, device, dtype),)
|
||||||
|
conditioning = diffusion.conditioner(metadata, device)
|
||||||
|
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
||||||
|
|
||||||
|
# Rectified flow: interpolate between data and noise
|
||||||
|
t = torch.rand(x0.shape[0], device=device, dtype=dtype) # [1]
|
||||||
|
noise = torch.randn_like(x0)
|
||||||
|
# t expanded for broadcast: [1] -> [1, 1, 1]
|
||||||
|
t_bcast = t[:, None, None]
|
||||||
|
x_t = (1.0 - t_bcast) * x0 + t_bcast * noise
|
||||||
|
v_target = noise - x0
|
||||||
|
|
||||||
|
with torch.amp.autocast(device_type=device.type, dtype=dtype):
|
||||||
|
v_pred = dit(x_t, t,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob=cfg_dropout_prob,
|
||||||
|
**cond_inputs)
|
||||||
|
|
||||||
|
loss = F.mse_loss(v_pred.float(), v_target.float())
|
||||||
|
|
||||||
|
if use_scaler:
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
else:
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if step % 50 == 0:
|
||||||
|
print(f"[PrismAudio] step {step}/{train_steps} loss={loss.item():.6f}", flush=True)
|
||||||
|
|
||||||
|
if step % save_every == 0:
|
||||||
|
ckpt_path = output_path.replace(".safetensors", f"_step{step}.safetensors")
|
||||||
|
save_file(_get_lora_state_dict(dit), ckpt_path)
|
||||||
|
print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True)
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Save final weights
|
||||||
|
save_file(_get_lora_state_dict(dit), output_path)
|
||||||
|
|
||||||
|
# Save config alongside weights so the loader knows the structure
|
||||||
|
config_path = output_path.replace(".safetensors", "_config.json")
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump({
|
||||||
|
"rank": lora_rank,
|
||||||
|
"alpha": lora_alpha,
|
||||||
|
"target_modules": sorted(target_attrs),
|
||||||
|
}, f, indent=2)
|
||||||
|
|
||||||
|
print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Always restore model to base state — even on exception.
|
||||||
|
# Without this, LoRA wrappers would persist in the cached model and
|
||||||
|
# subsequent training runs would apply LoRA on top of existing LoRA.
|
||||||
|
dit.eval()
|
||||||
|
_unapply_lora(dit)
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
diffusion.model.to(get_offload_device())
|
||||||
|
diffusion.conditioner.to(get_offload_device())
|
||||||
|
diffusion.pretransform.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
return (output_path,)
|
||||||
+19
-1
@@ -18,6 +18,7 @@ class PrismAudioSampler:
|
|||||||
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}),
|
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}),
|
||||||
"steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
|
"steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
|
||||||
"cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
|
"cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
|
||||||
|
"sync_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 3.0, "step": 0.05, "tooltip": "Scale factor for sync conditioning. Higher values tighten audio-visual sync at the cost of audio naturalness; 0.0 disables sync guidance entirely."}),
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -27,7 +28,7 @@ class PrismAudioSampler:
|
|||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
CATEGORY = PRISMAUDIO_CATEGORY
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
def generate(self, model, features, duration, steps, cfg_scale, seed):
|
def generate(self, model, features, duration, steps, cfg_scale, sync_strength, seed):
|
||||||
device = get_device()
|
device = get_device()
|
||||||
dtype = model["dtype"]
|
dtype = model["dtype"]
|
||||||
strategy = model["strategy"]
|
strategy = model["strategy"]
|
||||||
@@ -43,6 +44,16 @@ class PrismAudioSampler:
|
|||||||
# Compute latent dimensions
|
# Compute latent dimensions
|
||||||
latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
|
latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
|
||||||
|
|
||||||
|
# Sync temporal coverage diagnostic
|
||||||
|
sync_frames = features["sync_features"].shape[0]
|
||||||
|
sync_duration_covered = sync_frames / 25.0 # Synchformer always extracts at 25fps
|
||||||
|
print(f"[PrismAudio] sync: {sync_frames} frames @ 25fps = {sync_duration_covered:.2f}s | "
|
||||||
|
f"audio target: {latent_length} latent frames = {duration:.2f}s", flush=True)
|
||||||
|
if abs(sync_duration_covered - duration) > 0.5:
|
||||||
|
print(f"[PrismAudio] Warning: sync coverage ({sync_duration_covered:.2f}s) differs from "
|
||||||
|
f"audio duration ({duration:.2f}s) by more than 0.5s — consider re-extracting features "
|
||||||
|
f"with the correct video duration.", flush=True)
|
||||||
|
|
||||||
# Note: no seq length config needed — the model adapts to input tensor shapes
|
# Note: no seq length config needed — the model adapts to input tensor shapes
|
||||||
# dynamically via its transformer architecture.
|
# dynamically via its transformer architecture.
|
||||||
|
|
||||||
@@ -76,6 +87,13 @@ class PrismAudioSampler:
|
|||||||
if not has_video:
|
if not has_video:
|
||||||
_substitute_empty_features(diffusion, conditioning, device, dtype)
|
_substitute_empty_features(diffusion, conditioning, device, dtype)
|
||||||
|
|
||||||
|
# Scale sync conditioning after the conditioner MLP (clean linear scale,
|
||||||
|
# avoids SiLU nonlinearity in Sync_MLP). The CFG null path always uses zeros,
|
||||||
|
# so this directly scales the sync guidance magnitude: cfg_scale * (strength*cond - 0).
|
||||||
|
# Only applied when video is present — T2A uses learned empty_sync_feat, not raw sync.
|
||||||
|
if has_video and sync_strength != 1.0 and 'sync_features' in conditioning:
|
||||||
|
conditioning['sync_features'][0] = conditioning['sync_features'][0] * sync_strength
|
||||||
|
|
||||||
# Assemble conditioning inputs for the DiT
|
# Assemble conditioning inputs for the DiT
|
||||||
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
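The comment in this hunk argues that, because the CFG null path is all zeros, multiplying the sync conditioning by `sync_strength` scales the guidance term linearly. A minimal sketch of that arithmetic, using plain Python lists in place of tensors, is shown here. `guidance_term` is a hypothetical name, and the real sampler may combine this term with other conditioning paths not shown.

```python
def guidance_term(cond, cfg_scale, sync_strength=1.0):
    """CFG difference term for sync conditioning with an all-zero null path.

    Illustrative only: computes cfg_scale * (sync_strength * cond - 0)
    element-wise, as described in the diff comment.
    """
    null = [0.0] * len(cond)
    scaled = [sync_strength * c for c in cond]
    return [cfg_scale * (s - n) for s, n in zip(scaled, null)]


cond = [1.0, -2.0]
full = guidance_term(cond, cfg_scale=4.0)                     # strength 1.0
half = guidance_term(cond, cfg_scale=4.0, sync_strength=0.5)  # strength 0.5
print(full, half)
```

Halving `sync_strength` halves every component of the guidance term, which is why the scale is applied after the conditioner MLP rather than before its SiLU nonlinearity.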
|
||||||
|
|
||||||
|
|||||||
@@ -9,3 +9,4 @@ descript-audio-codec
|
|||||||
vector-quantize-pytorch
|
vector-quantize-pytorch
|
||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
|
torchaudio
|
||||||
|
|||||||
@@ -85,12 +85,13 @@ def main():
|
|||||||
duration = total_frames / fps
|
duration = total_frames / fps
|
||||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||||
|
|
||||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
|
||||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||||
clip_frames = all_frames[clip_indices]
|
clip_frames = all_frames[clip_indices]
|
||||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||||
|
|
||||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
# Synchformer processes in segments of 8; ensure at least 8 frames
|
||||||
|
sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
|
||||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||||
sync_frames = all_frames[sync_indices]
|
sync_frames = all_frames[sync_indices]
|
||||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||||
@@ -102,12 +103,13 @@ def main():
|
|||||||
duration = total_frames / fps
|
duration = total_frames / fps
|
||||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||||
|
|
||||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
|
||||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||||
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
||||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||||
|
|
||||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
# Synchformer processes in segments of 8; ensure at least 8 frames
|
||||||
|
sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
|
||||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||||
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
||||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
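Both extraction hunks apply the same pattern: compute a target frame count, enforce a minimum (1 for CLIP, 8 for Synchformer segments), map each target index back to a source frame, and clamp to the last frame. A sketch of that pattern as one helper follows; `sample_indices` and `min_count` are hypothetical names, not functions from the repository.

```python
def sample_indices(total_frames, fps, target_fps, min_count=1):
    """Pick source-frame indices for resampling a video to target_fps.

    Hypothetical helper mirroring the index logic in the diff. For clips
    shorter than min_count target frames, out-of-range indices are
    clamped, so the last source frame is repeated.
    """
    duration = total_frames / fps
    count = max(min_count, int(duration * target_fps))
    indices = [int(i * fps / target_fps) for i in range(count)]
    return [min(i, total_frames - 1) for i in indices]


# A 5-frame clip at 25 fps still yields 8 sync indices; the tail repeats frame 4.
print(sample_indices(total_frames=5, fps=25.0, target_fps=25.0, min_count=8))
```

Without the `max(...)` floor, a clip shorter than one target-frame interval would produce an empty index list, which is the failure this change guards against.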
|
||||||
|
|||||||