chore: remove all PrismAudio code from main branch
- Delete prismaudio_core/, data_utils/, scripts/, docs/plans/
- Delete PrismAudio nodes (feature_extractor, feature_loader, model_loader, sampler, text_only)
- Delete PrismAudio workflows (video_to_audio, text_to_audio)
- Clean nodes/utils.py: rename PRISMAUDIO_CATEGORY → SELVA_CATEGORY, remove unused helpers
- Strip PrismAudio-only deps from requirements.txt

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,207 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
import subprocess
|
||||
import tempfile
|
||||
import torch
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY
|
||||
from .feature_loader import PrismAudioFeatureLoader
|
||||
|
||||
# Managed venv created automatically when python_env is left as default
|
||||
_PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__))
|
||||
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
|
||||
_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")
|
||||
|
||||
_EXTRACT_PACKAGES = [
|
||||
"torch", "torchaudio", "torchvision",
|
||||
# TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
|
||||
"tensorflow-cpu>=2.16.0",
|
||||
# jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed)
|
||||
"jax[cuda13]", "flax",
|
||||
"transformers", "decord", "einops", "numpy", "mediapy",
|
||||
"git+https://github.com/google-deepmind/videoprism.git",
|
||||
]
|
||||
|
||||
|
||||
def _pip_install(pip, *packages, label=None):
|
||||
"""Install one or more packages with visible output; raise on failure."""
|
||||
tag = label or packages[0]
|
||||
print(f"[PrismAudio] installing {tag} ...", flush=True)
|
||||
result = subprocess.run(
|
||||
[pip, "install", "--progress-bar", "on"] + list(packages),
|
||||
capture_output=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"[PrismAudio] Failed to install {tag} (exit {result.returncode}). "
|
||||
"See pip output above for details."
|
||||
)
|
||||
print(f"[PrismAudio] {tag} OK", flush=True)
|
||||
|
||||
|
||||
def _ensure_extract_env():
    """Create and populate the managed venv on first use.

    The venv is considered ready when its interpreter (``_MANAGED_PYTHON``)
    exists. Because that interpreter appears as soon as ``python -m venv``
    succeeds, any failure while installing packages removes the whole venv
    again; otherwise the next call would return a half-populated environment.

    Returns:
        Path of the Python interpreter inside the managed venv.

    Raises:
        subprocess.CalledProcessError: If venv creation or the pip upgrade fails.
        RuntimeError: If installing one of ``_EXTRACT_PACKAGES`` fails.
    """
    if os.path.exists(_MANAGED_PYTHON):
        return _MANAGED_PYTHON

    import shutil
    if os.path.exists(_MANAGED_VENV):
        print("[PrismAudio] Removing incomplete venv and retrying...", flush=True)
        shutil.rmtree(_MANAGED_VENV)

    print(f"[PrismAudio] Creating feature-extraction venv at: {_MANAGED_VENV}", flush=True)
    try:
        subprocess.run([sys.executable, "-m", "venv", _MANAGED_VENV], check=True)

        # NOTE(review): "bin" is the POSIX venv layout; Windows venvs use
        # "Scripts" — confirm which platforms must be supported.
        pip = os.path.join(_MANAGED_VENV, "bin", "pip")

        print("[PrismAudio] Upgrading pip...", flush=True)
        subprocess.run([pip, "install", "--upgrade", "pip"], check=True)

        total = len(_EXTRACT_PACKAGES)
        print(f"[PrismAudio] Installing {total} package groups — this may take several minutes...", flush=True)

        for i, pkg in enumerate(_EXTRACT_PACKAGES, 1):
            # Short display name: repo file name for git URLs, otherwise the
            # requirement with version pins and extras stripped.
            label = pkg.split("/")[-1] if pkg.startswith("git+") else pkg.split(">=")[0].split("==")[0].split("[")[0]
            print(f"[PrismAudio] [{i}/{total}] {label}", flush=True)
            _pip_install(pip, pkg, label=label)
    except BaseException:
        # Includes KeyboardInterrupt: a cancelled install must not leave an
        # interpreter behind that the readiness check would accept.
        shutil.rmtree(_MANAGED_VENV, ignore_errors=True)
        raise

    print("[PrismAudio] Feature-extraction env ready.", flush=True)
    return _MANAGED_PYTHON
|
||||
|
||||
|
||||
def _hash_inputs(video_tensor, cot_text):
|
||||
"""Create a hash of the inputs for caching."""
|
||||
h = hashlib.sha256()
|
||||
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed
|
||||
h.update(cot_text.encode())
|
||||
return h.hexdigest()[:16]
|
||||
|
||||
|
||||
def _save_frames_to_npy(video_tensor, output_path):
|
||||
"""Save ComfyUI IMAGE tensor [T,H,W,C] float32 [0,1] to .npy as uint8.
|
||||
|
||||
Lossless — avoids H.264 encode/decode roundtrip.
|
||||
"""
|
||||
import numpy as np
|
||||
frames_np = (video_tensor.cpu().numpy() * 255).astype("uint8")
|
||||
np.save(output_path, frames_np)
|
||||
|
||||
|
||||
class PrismAudioFeatureExtractor:
    """ComfyUI node that extracts PrismAudio conditioning features from a video.

    Frames are written to a temporary ``.npy`` file and passed to
    ``scripts/extract_features.py``, which runs in a separate Python process
    (the managed venv by default). Results are cached as ``.npz`` files.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input declaration dict."""
        return {
            "required": {
                "video": ("IMAGE",),
                "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
            },
            "optional": {
                "video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info output to auto-set fps."}),
                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001, "tooltip": "Frame rate of the input video. Ignored if video_info is connected."}),
                "python_env": (["managed_env", "comfyui_env"], {"tooltip": "managed_env: auto-created isolated venv with JAX/TF (recommended). comfyui_env: current ComfyUI Python — WARNING: may conflict with existing packages and destabilize ComfyUI."}),
                "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
                "hf_token": ("STRING", {"default": "", "tooltip": "HuggingFace token for gated models (e.g. google/t5gemma). Get yours at huggingface.co/settings/tokens"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES", "FLOAT")
    RETURN_NAMES = ("features", "fps")
    FUNCTION = "extract_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def extract_features(self, video, caption_cot, video_info=None, fps=30.0, python_env="managed_env", cache_dir="", hf_token=""):
        """Extract (or load cached) features for ``video`` and ``caption_cot``.

        Args:
            video: Frame tensor; its first dimension is the frame count.
            caption_cot: Caption text forwarded to the extraction script.
            video_info: Optional mapping; when given, ``video_info["loaded_fps"]``
                overrides ``fps``.
            fps: Source frame rate, forwarded as ``--source_fps``.
            python_env: ``"comfyui_env"`` runs the script with the current
                interpreter; anything else uses the managed venv.
            cache_dir: Cache directory; empty means a folder under the system
                temp dir.
            hf_token: HuggingFace token; blank falls back to ``HF_TOKEN``.

        Returns:
            ``(features, fps)`` where ``fps`` is a float.

        Raises:
            RuntimeError: If the Synchformer checkpoint is missing or the
                extraction subprocess exits non-zero.
            subprocess.TimeoutExpired: If extraction exceeds 10 minutes.
        """
        # Resolve fps from VHS video_info if connected
        if video_info is not None:
            fps = video_info["loaded_fps"]

        # Resolve python binary
        if python_env == "comfyui_env":
            print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
                  "Installing them here may conflict with existing packages and destabilize ComfyUI.", flush=True)
            python_bin = sys.executable
        else:
            python_bin = _ensure_extract_env()

        # Determine cache directory
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features")
        os.makedirs(cache_dir, exist_ok=True)

        # Check cache. fps is part of the key because it is passed to the
        # extraction script (--source_fps); the same frames at another fps
        # must not reuse a stale entry.
        cache_hash = _hash_inputs(video, caption_cot)
        cached_path = os.path.join(cache_dir, f"{cache_hash}_fps{float(fps):g}.npz")
        if os.path.exists(cached_path):
            print(f"[PrismAudio] Using cached features: {cached_path}")
            loader = PrismAudioFeatureLoader()
            features, = loader.load_features(cached_path)
            return (features, float(fps))

        script_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "scripts", "extract_features.py"
        )

        # Validate the checkpoint before creating any temp files so an early
        # failure leaves nothing behind.
        import folder_paths
        synchformer_ckpt = os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth")
        if not os.path.exists(synchformer_ckpt):
            raise RuntimeError(
                f"[PrismAudio] Synchformer checkpoint not found: {synchformer_ckpt}\n"
                "Download synchformer_state_dict.pth from FunAudioLLM/PrismAudio and place it in models/prismaudio/."
            )

        import time
        tmp_video = None
        # The subprocess writes here; the file is promoted to cached_path only
        # after a clean exit, so a crash or timeout cannot poison the cache.
        partial_path = os.path.join(cache_dir, f"{cache_hash}.{os.getpid()}.partial.npz")
        try:
            # Save frames to temp file (lossless .npy, no codec roundtrip)
            t0 = time.perf_counter()
            frames = video.shape[0]
            print(f"[PrismAudio] Saving {frames} frames to .npy (fps={fps})...", flush=True)
            with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
                tmp_video = tmp.name
            _save_frames_to_npy(video, tmp_video)
            print(f"[PrismAudio] Frames saved in {time.perf_counter() - t0:.1f}s", flush=True)

            # Build subprocess command
            cmd = [
                python_bin,
                script_path,
                "--video", tmp_video,
                "--cot_text", caption_cot,
                "--output", partial_path,
                "--source_fps", str(fps),
                "--synchformer_ckpt", synchformer_ckpt,
            ]

            # Build env: inherit current env, inject HF token if provided.
            # os.environ.copy() returns a plain dict; copy.copy(os.environ)
            # would still write through to the real process environment.
            env = os.environ.copy()
            token = (hf_token or "").strip() or os.environ.get("HF_TOKEN", "")
            if token:
                env["HF_TOKEN"] = token
                env["HUGGING_FACE_HUB_TOKEN"] = token
            else:
                print("[PrismAudio] Warning: no HF_TOKEN set — gated models (e.g. t5gemma) will fail. "
                      "Add your token in the hf_token input or set HF_TOKEN env var.", flush=True)

            print(f"[PrismAudio] Extracting features via subprocess (output streams live)...")
            # capture_output=False: let stdout/stderr stream directly to ComfyUI logs
            result = subprocess.run(
                cmd,
                capture_output=False,
                timeout=600,  # 10 minute timeout
                env=env,
            )
            if result.returncode != 0:
                raise RuntimeError(
                    f"[PrismAudio] Feature extraction subprocess exited with code {result.returncode}. "
                    "See output above for details."
                )
            print("[PrismAudio] Feature extraction subprocess finished successfully.")
            os.replace(partial_path, cached_path)
        finally:
            for leftover in (tmp_video, partial_path):
                if leftover and os.path.exists(leftover):
                    os.unlink(leftover)

        # Load the extracted features
        loader = PrismAudioFeatureLoader()
        features, = loader.load_features(cached_path)
        return (features, float(fps))
|
||||
@@ -1,53 +0,0 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from .utils import PRISMAUDIO_CATEGORY
|
||||
|
||||
# Keys consumed by the conditioners (video_features, text_features, sync_features)
|
||||
# global_video_features and global_text_features are NOT consumed by any conditioner
|
||||
# in the prismaudio.json config — they are unused.
|
||||
REQUIRED_KEYS = [
|
||||
"video_features",
|
||||
"text_features",
|
||||
"sync_features",
|
||||
]
|
||||
|
||||
|
||||
class PrismAudioFeatureLoader:
    """ComfyUI node that loads pre-computed PrismAudio features from a ``.npz`` file."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input declaration dict."""
        return {
            "required": {
                "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
    RETURN_NAMES = ("features",)
    FUNCTION = "load_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_features(self, npz_path):
        """Read the arrays named in ``REQUIRED_KEYS`` from ``npz_path``.

        Args:
            npz_path: Path of the ``.npz`` archive.

        Returns:
            A 1-tuple holding a dict of float tensors keyed by feature name,
            plus ``"duration"`` (float) when the archive contains it. Missing
            keys are replaced with zero tensors.

        Raises:
            FileNotFoundError: If ``npz_path`` does not exist.
        """
        if not os.path.exists(npz_path):
            raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")

        features = {}
        # NOTE(security): allow_pickle=True lets a crafted archive run
        # arbitrary code when an object array is read. npz_path is user input;
        # switch to allow_pickle=False once it is confirmed that no producer
        # stores object arrays.
        # The context manager closes the archive's file handle.
        with np.load(npz_path, allow_pickle=True) as data:
            for key in REQUIRED_KEYS:
                if key in data:
                    features[key] = torch.from_numpy(data[key]).float()
                    continue

                print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
                # Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None
                # Sync_MLP requires length divisible by 8 (segments of 8 frames)
                if key == "sync_features":
                    features[key] = torch.zeros(8, 768)
                else:
                    features[key] = torch.zeros(1, 1024)

            # Load duration if present; accept a 0-d or a 1-element array.
            if "duration" in data:
                features["duration"] = float(np.asarray(data["duration"]).reshape(-1)[0])

        return (features,)
|
||||
@@ -1,154 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
|
||||
get_device, get_offload_device, determine_precision, determine_offload_strategy,
|
||||
soft_empty_cache, resolve_hf_token,
|
||||
)
|
||||
|
||||
# HuggingFace repo for auto-download
|
||||
HF_REPO_ID = "FunAudioLLM/PrismAudio"
|
||||
REQUIRED_FILES = {
|
||||
"diffusion": "prismaudio.ckpt",
|
||||
"vae": "vae.ckpt",
|
||||
"synchformer": "synchformer_state_dict.pth",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(filename, model_dir, hf_token=None):
|
||||
"""Download a model file from HuggingFace if not present locally."""
|
||||
filepath = os.path.join(model_dir, filename)
|
||||
if os.path.exists(filepath):
|
||||
return filepath
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
|
||||
try:
|
||||
downloaded = hf_hub_download(
|
||||
repo_id=HF_REPO_ID,
|
||||
filename=filename,
|
||||
local_dir=model_dir,
|
||||
token=hf_token or None,
|
||||
)
|
||||
return downloaded
|
||||
except Exception as e:
|
||||
if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower():
|
||||
raise RuntimeError(
|
||||
f"[PrismAudio] Model '{filename}' requires license acceptance. "
|
||||
f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
|
||||
f"then set HF_TOKEN env var or run: huggingface-cli login"
|
||||
) from e
|
||||
raise
|
||||
|
||||
|
||||
class PrismAudioModelLoader:
    """ComfyUI node that builds the PrismAudio model and loads its checkpoints."""

    RETURN_TYPES = ("PRISMAUDIO_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = PRISMAUDIO_CATEGORY

    @classmethod
    def INPUT_TYPES(cls):
        """Register the model folder, then return the input declaration dict."""
        register_model_folder()
        required = {
            "precision": (["auto", "fp32", "fp16", "bf16"],),
            "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
        }
        return {"required": required}

    def load_model(self, precision, offload_strategy):
        """Download missing files, build the model from its JSON config and load weights.

        Args:
            precision: One of the ``precision`` choices from ``INPUT_TYPES``.
            offload_strategy: One of the ``offload_strategy`` choices.

        Returns:
            A 1-tuple holding a dict with keys ``model``, ``dtype``,
            ``strategy``, ``config`` and ``model_dir``.
        """
        target_device = get_device()
        compute_dtype = determine_precision(precision, target_device)
        resolved_strategy = determine_offload_strategy(offload_strategy)
        hf_token = resolve_hf_token()
        weights_dir = get_prismaudio_model_dir()

        # Auto-download missing files
        for required_name in REQUIRED_FILES.values():
            _download_if_missing(required_name, weights_dir, hf_token=hf_token)

        # Load config
        plugin_root = os.path.dirname(os.path.dirname(__file__))
        config_path = os.path.join(plugin_root, "prismaudio_core", "configs", "prismaudio.json")
        with open(config_path) as config_file:
            model_config = json.load(config_file)

        # Create model from config
        from prismaudio_core.factory import create_model_from_config
        net = create_model_from_config(model_config)

        # Load diffusion weights
        diffusion_state = comfy.utils.load_torch_file(
            os.path.join(weights_dir, REQUIRED_FILES["diffusion"])
        )
        # Some checkpoints wrap the weights in {"state_dict": ...}
        if "state_dict" in diffusion_state:
            diffusion_state = diffusion_state["state_dict"]
        diffusion_report = net.load_state_dict(diffusion_state, strict=False)
        missing_diff = diffusion_report.missing_keys
        unexpected_diff = diffusion_report.unexpected_keys
        print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
        print(f"[PrismAudio] Diffusion load: missing={len(missing_diff)}, unexpected={len(unexpected_diff)}", flush=True)
        if missing_diff:
            print(f"[PrismAudio]   missing (first 10): {missing_diff[:10]}", flush=True)
        if unexpected_diff:
            print(f"[PrismAudio]   unexpected (first 5): {unexpected_diff[:5]}", flush=True)
        # Sample a few ckpt keys to verify prefix alignment
        print(f"[PrismAudio]   ckpt key samples: {list(diffusion_state.keys())[:5]}", flush=True)

        # Load VAE weights separately
        # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
        raw_vae_state = comfy.utils.load_torch_file(
            os.path.join(weights_dir, REQUIRED_FILES["vae"])
        )
        print(f"[PrismAudio] VAE ckpt: {len(raw_vae_state)} keys in file", flush=True)
        # Sample raw keys to see actual prefix
        print(f"[PrismAudio] VAE raw key samples: {list(raw_vae_state.keys())[:8]}", flush=True)

        # Strip "autoencoder." prefix from keys
        prefix = "autoencoder."
        vae_state = {
            (name[len(prefix):] if name.startswith(prefix) else name): tensor
            for name, tensor in raw_vae_state.items()
        }
        print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)
        # Sample model keys to compare
        print(f"[PrismAudio] pretransform model key samples: {list(net.pretransform.state_dict().keys())[:5]}", flush=True)

        # strict=False: vae.ckpt is a training checkpoint that also contains
        # discriminator, loss modules, and EMA wrappers not present in the
        # inference AudioAutoencoder — ignore those extra keys.
        # Load directly into the inner AudioAutoencoder to get IncompatibleKeys back
        # (AutoencoderPretransform.load_state_dict doesn't return the result)
        vae_report = net.pretransform.model.load_state_dict(vae_state, strict=False)
        print(f"[PrismAudio] VAE load: missing={len(vae_report.missing_keys)}, unexpected={len(vae_report.unexpected_keys)}", flush=True)
        if vae_report.missing_keys:
            print(f"[PrismAudio]   VAE missing (first 10): {vae_report.missing_keys[:10]}", flush=True)

        # Apply precision: DiT + conditioners in user-selected dtype,
        # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
        net.model.to(compute_dtype)  # DiTWrapper
        net.conditioner.to(compute_dtype)  # MultiConditioner
        # net.pretransform stays in fp32

        placement = target_device if resolved_strategy == "keep_in_vram" else get_offload_device()
        net = net.to(placement)

        net.eval()

        bundle = {
            "model": net,
            "dtype": compute_dtype,
            "strategy": resolved_strategy,
            "config": model_config,
            "model_dir": weights_dir,
        }
        return (bundle,)
|
||||
@@ -1,165 +0,0 @@
|
||||
import torch
|
||||
import comfy.model_management as mm
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
|
||||
get_device, get_offload_device, soft_empty_cache,
|
||||
)
|
||||
|
||||
|
||||
class PrismAudioSampler:
    """ComfyUI node that samples audio from a loaded PrismAudio model and features."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input declaration dict."""
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "features": ("PRISMAUDIO_FEATURES",),
                "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}),
                "steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
                "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, features, duration, steps, cfg_scale, seed):
        """Generate audio conditioned on the given features.

        Args:
            model: Dict with keys ``model``, ``dtype`` and ``strategy``.
            features: Dict with ``text_features`` and, optionally,
                ``video_features``, ``sync_features`` and ``duration``.
            duration: Seconds of audio; ``<= 0`` means use ``features["duration"]``.
            steps: Number of sampling steps.
            cfg_scale: Classifier-free guidance scale.
            seed: Seed for the initial noise.

        Returns:
            A 1-tuple holding ``{"waveform": tensor, "sample_rate": SAMPLE_RATE}``.

        Raises:
            ValueError: If ``duration <= 0`` and ``features`` has no duration.
        """
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]

        # Resolve duration: 0 means use video duration from features
        if duration <= 0:
            if "duration" not in features:
                raise ValueError("[PrismAudio] duration=0 but features contain no duration. Set duration manually or use PrismAudioFeatureExtractor.")
            duration = features["duration"]
            print(f"[PrismAudio] Using video duration from features: {duration:.2f}s", flush=True)

        # Compute latent dimensions
        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Note: no seq length config needed — the model adapts to input tensor shapes
        # dynamically via its transformer architecture.

        # Determine if video features are present (not all zeros).
        # has_video is a plain Python bool so it can feed both the `if` below
        # and torch.tensor() without copy-constructing from a tensor.
        video_raw = features.get("video_features")
        sync_raw = features.get("sync_features")
        has_video = video_raw is not None and bool(video_raw.abs().sum() > 0)

        # Absent features get the same zero placeholders the feature loader
        # and the text-only node use; they are replaced by the learned empty
        # embeddings after conditioning.
        if video_raw is None:
            video_raw = torch.zeros(1, 1024)
        if sync_raw is None:
            sync_raw = torch.zeros(8, 768)

        video_feat = video_raw.to(device, dtype=dtype)
        sync_feat = sync_raw.to(device, dtype=dtype)

        # Build metadata as a TUPLE of dicts (one per batch sample)
        # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
        sample_meta = {
            "video_features": video_feat,
            "text_features": features["text_features"].to(device, dtype=dtype),
            "sync_features": sync_feat,
            "video_exist": torch.tensor(has_video),
        }
        metadata = (sample_meta,)

        # Move model to device if offloaded
        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            # Run conditioning
            conditioning = diffusion.conditioner(metadata, device)

            # Handle missing video: substitute learned empty embeddings
            if not has_video:
                _substitute_empty_features(diffusion, conditioning, device, dtype)

            # Assemble conditioning inputs for the DiT
            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (MPS doesn't support torch.Generator)
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            # Sample with progress bar
            pbar = comfy.utils.ProgressBar(steps)

            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

            fakes_f = fakes.float()
            print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)

            # Offload diffusion model and conditioner before VAE decode
            if strategy == "offload_to_cpu":
                diffusion.model.to(get_offload_device())
                diffusion.conditioner.to(get_offload_device())
                soft_empty_cache()
                diffusion.pretransform.to(device)

            # VAE decode in fp32 (snake activations overflow in fp16)
            with torch.amp.autocast(device_type=device.type, enabled=False):
                audio = diffusion.pretransform.decode(fakes_f)

            # Offload VAE
            if strategy == "offload_to_cpu":
                diffusion.pretransform.to(get_offload_device())
                soft_empty_cache()

        # Peak normalize then clamp (matching reference: div by max abs before clamp)
        audio = audio.float()
        pre_norm_std = audio.std().item()
        pre_norm_peak = audio.abs().max().item()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)
        print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)

        # Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int}
        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
|
||||
|
||||
|
||||
def _substitute_empty_features(diffusion, conditioning, device, dtype):
|
||||
"""Replace video/sync conditioning with learned empty embeddings when video is absent.
|
||||
|
||||
empty_clip_feat and empty_sync_feat are learned null embeddings in the conditioner
|
||||
output space (1024-dim). Passing zero features through bias-free Cond_MLP produces
|
||||
near-zero activations, NOT the learned null signal the model was trained with.
|
||||
|
||||
The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim].
|
||||
"""
|
||||
dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model
|
||||
|
||||
# Substitute video_features with learned empty_clip_feat
|
||||
if hasattr(dit, 'empty_clip_feat') and 'video_features' in conditioning:
|
||||
empty = dit.empty_clip_feat.to(device, dtype=dtype) # [1, 1024]
|
||||
batch_size = conditioning['video_features'][0].shape[0]
|
||||
empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024]
|
||||
conditioning['video_features'][0] = empty_expanded
|
||||
conditioning['video_features'][1] = torch.ones(batch_size, 1, device=device)
|
||||
|
||||
# Substitute sync_features with learned empty_sync_feat
|
||||
if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning:
|
||||
empty = dit.empty_sync_feat.to(device, dtype=dtype) # [1, 1024]
|
||||
batch_size = conditioning['sync_features'][0].shape[0]
|
||||
empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024]
|
||||
conditioning['sync_features'][0] = empty_expanded
|
||||
conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device)
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
|
||||
# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
|
||||
_CLIP_SIZE = 384
|
||||
@@ -68,7 +68,7 @@ class SelvaFeatureExtractor:
|
||||
RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
|
||||
RETURN_NAMES = ("features", "fps", "prompt")
|
||||
FUNCTION = "extract_features"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
|
||||
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
||||
duration=0.0, cache_dir=""):
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
import torch
|
||||
import folder_paths
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY, get_offload_device, determine_offload_strategy
|
||||
from .utils import SELVA_CATEGORY, get_offload_device, determine_offload_strategy
|
||||
|
||||
# Variant → (generator filename, mode, has_bigvgan)
|
||||
_VARIANTS = {
|
||||
@@ -96,7 +96,7 @@ class SelvaModelLoader:
|
||||
RETURN_TYPES = ("SELVA_MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "load_model"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
|
||||
def load_model(self, variant, precision, offload_strategy):
|
||||
from selva_core.model.networks_generator import get_my_mmaudio
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
|
||||
|
||||
class SelvaSampler:
|
||||
@@ -35,7 +35,7 @@ class SelvaSampler:
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
RETURN_NAMES = ("audio",)
|
||||
FUNCTION = "generate"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
|
||||
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed):
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
import torch
|
||||
import comfy.model_management as mm
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
|
||||
get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
|
||||
)
|
||||
from .sampler import _substitute_empty_features
|
||||
|
||||
|
||||
class PrismAudioTextOnly:
    """ComfyUI node that generates audio from a text prompt alone (no video features)."""

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input declaration dict."""
        required = {
            "model": ("PRISMAUDIO_MODEL",),
            "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Detailed chain-of-thought description of the audio scene. Use long, descriptive text — e.g. 'A large dog barks sharply twice, with ambient outdoor background noise. The sound is clear and close.' Short prompts produce lower quality."}),
            "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
            "steps": ("INT", {"default": 100, "min": 1, "max": 100}),
            "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1}),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
        }
        return {"required": required}

    def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
        """Encode ``text_prompt``, sample latents and decode them to a waveform.

        Args:
            model: Dict with keys ``model``, ``dtype`` and ``strategy``.
            text_prompt: Description of the audio to generate.
            duration: Seconds of audio to generate.
            steps: Number of sampling steps.
            cfg_scale: Classifier-free guidance scale.
            seed: Seed for the initial noise.

        Returns:
            A 1-tuple holding ``{"waveform": tensor, "sample_rate": SAMPLE_RATE}``.
        """
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]
        offloading = strategy == "offload_to_cpu"

        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Encode text with T5-Gemma
        text_features = _encode_text_t5(text_prompt, device, dtype)

        # Build metadata: tuple of one dict per sample
        # Use zero tensors for video/sync (not None — Cond_MLP crashes on None via pad_sequence)
        # Sync_MLP requires length divisible by 8 (segments of 8 frames) — minimum [8, 768]
        # These will be substituted with learned empty embeddings after conditioning
        placeholder_video = torch.zeros(1, 1024, device=device, dtype=dtype)
        prompt_features = text_features.to(device, dtype=dtype)
        placeholder_sync = torch.zeros(8, 768, device=device, dtype=dtype)
        metadata = ({
            "video_features": placeholder_video,
            "text_features": prompt_features,
            "sync_features": placeholder_sync,
            "video_exist": torch.tensor(False),
        },)

        if offloading:
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            conditioning = diffusion.conditioner(metadata, device)

            # Substitute empty features for video/sync
            _substitute_empty_features(diffusion, conditioning, device, dtype)

            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (MPS doesn't support torch.Generator)
            gen_device = "cpu" if device.type == "mps" else device
            rng = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=rng,
                device=gen_device,
            )
            noise = noise.to(device=device, dtype=dtype)

            pbar = comfy.utils.ProgressBar(steps)

            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

            fakes_f = fakes.float()
            print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)

            if offloading:
                offload_target = get_offload_device()
                diffusion.model.to(offload_target)
                diffusion.conditioner.to(get_offload_device())
                soft_empty_cache()
                diffusion.pretransform.to(device)

            # VAE decode in fp32 (snake activations overflow in fp16)
            with torch.amp.autocast(device_type=device.type, enabled=False):
                audio = diffusion.pretransform.decode(fakes_f)

            if offloading:
                diffusion.pretransform.to(get_offload_device())
                soft_empty_cache()

        # Peak normalize then clamp
        audio = audio.float()
        pre_norm_std = audio.std().item()
        pre_norm_peak = audio.abs().max().item()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)
        print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
        print(f"[PrismAudio] audio shape: {tuple(audio.shape)}", flush=True)

        result = {"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE}
        return (result,)
|
||||
|
||||
|
||||
# T5-Gemma encoder singleton
|
||||
_t5_model = None
|
||||
_t5_tokenizer = None
|
||||
|
||||
|
||||
def _encode_text_t5(text, device, dtype):
    """Return T5-Gemma encoder hidden states for ``text``.

    The encoder comes from ``AutoModelForSeq2SeqLM(...).get_encoder()`` to
    match the reference FeaturesUtils.encode_t5_text() implementation, and no
    truncation is applied (matching reference behavior). The tokenizer and
    encoder are loaded once and kept in module-level globals; the encoder is
    moved to ``device`` for the forward pass and back to CPU afterwards.

    Returns:
        ``last_hidden_state`` with its leading dimension squeezed: [seq_len, dim].
    """
    global _t5_model, _t5_tokenizer

    first_use = _t5_model is None
    if first_use:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        model_id = "google/t5gemma-l-l-ul2-it"
        token = resolve_hf_token()
        print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
        _t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token)
        _t5_model = seq2seq.get_encoder()
        _t5_model.eval()

    _t5_model.to(device, dtype=dtype)

    encoded = _t5_tokenizer(text, return_tensors="pt", padding=True)
    encoded = encoded.to(device)

    with torch.no_grad():
        outputs = _t5_model(**encoded)

    # Move T5 off GPU after encoding to save VRAM
    _t5_model.to("cpu")
    soft_empty_cache()

    hidden = outputs.last_hidden_state
    return hidden.squeeze(0)  # [seq_len, dim]
|
||||
+4
-47
@@ -1,21 +1,7 @@
|
||||
import os
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
|
||||
PRISMAUDIO_CATEGORY = "PrismAudio"
|
||||
SAMPLE_RATE = 44100
|
||||
DOWNSAMPLING_RATIO = 2048
|
||||
IO_CHANNELS = 64
|
||||
|
||||
def get_prismaudio_model_dir():
    """Return ``<models_dir>/prismaudio``, creating the directory if needed."""
    target = os.path.join(folder_paths.models_dir, "prismaudio")
    if not os.path.isdir(target):
        os.makedirs(target, exist_ok=True)
    return target
|
||||
|
||||
def register_model_folder():
    """Register the PrismAudio model directory with ComfyUI under the name "prismaudio"."""
    folder_paths.add_model_folder_path("prismaudio", get_prismaudio_model_dir())
|
||||
SELVA_CATEGORY = "SelVA"
|
||||
|
||||
def get_device():
    """Return the torch device reported by ``mm.get_torch_device()``."""
    device = mm.get_torch_device()
    return device
|
||||
@@ -23,42 +9,13 @@ def get_device():
|
||||
def get_offload_device():
    """Return the device reported by ``mm.unet_offload_device()``."""
    offload_device = mm.unet_offload_device()
    return offload_device
|
||||
|
||||
def get_free_memory(device=None):
    """Return ``mm.get_free_memory`` for ``device`` (default: ``get_device()``)."""
    target = get_device() if device is None else device
    return mm.get_free_memory(target)
|
||||
|
||||
def soft_empty_cache():
    """Delegate to ``mm.soft_empty_cache()``; returns nothing."""
    mm.soft_empty_cache()
    return None
|
||||
|
||||
def determine_precision(preference, device):
    """Map a precision preference to a torch dtype.

    Args:
        preference: ``"fp32"``, ``"fp16"``, ``"bf16"`` or ``"auto"``.
        device: Object with a ``type`` attribute (read only for ``"auto"``).

    Returns:
        The explicitly requested dtype. For ``"auto"``: float32 on CPU,
        bfloat16 when CUDA is available and reports bf16 support, otherwise
        float16.

    Raises:
        KeyError: If ``preference`` is not one of the values above.
    """
    if preference == "auto":
        if device.type == "cpu":
            return torch.float32
        bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        return torch.bfloat16 if bf16_ok else torch.float16

    explicit = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    return explicit[preference]
|
||||
|
||||
def determine_offload_strategy(preference):
|
||||
if preference != "auto":
|
||||
return preference
|
||||
free_mem = get_free_memory()
|
||||
gb = free_mem / (1024 ** 3)
|
||||
if gb >= 24:
|
||||
free_mem = mm.get_free_memory(get_device())
|
||||
if free_mem / (1024 ** 3) >= 16:
|
||||
return "keep_in_vram"
|
||||
else:
|
||||
return "offload_to_cpu"
|
||||
|
||||
def try_import_flash_attn():
    """Return the ``flash_attn`` module, or ``None`` when it cannot be imported."""
    try:
        import flash_attn
    except ImportError:
        return None
    else:
        return flash_attn
|
||||
|
||||
def resolve_hf_token():
    """Return the ``HF_TOKEN`` environment variable, or ``None`` when unset or empty."""
    return os.environ.get("HF_TOKEN") or None
|
||||
return "offload_to_cpu"
|
||||
|
||||
Reference in New Issue
Block a user