feat: PrismAudioModelLoader node with auto-download and adaptive VRAM
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,129 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
|
||||
get_device, get_offload_device, determine_precision, determine_offload_strategy,
|
||||
soft_empty_cache, resolve_hf_token,
|
||||
)
|
||||
|
||||
# HuggingFace repo for auto-download
|
||||
HF_REPO_ID = "FunAudioLLM/PrismAudio"
|
||||
REQUIRED_FILES = {
|
||||
"diffusion": "prismaudio.ckpt",
|
||||
"vae": "vae.ckpt",
|
||||
"synchformer": "synchformer_state_dict.pth",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(filename, model_dir, hf_token=None):
|
||||
"""Download a model file from HuggingFace if not present locally."""
|
||||
filepath = os.path.join(model_dir, filename)
|
||||
if os.path.exists(filepath):
|
||||
return filepath
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
|
||||
try:
|
||||
downloaded = hf_hub_download(
|
||||
repo_id=HF_REPO_ID,
|
||||
filename=filename,
|
||||
local_dir=model_dir,
|
||||
token=hf_token or None,
|
||||
)
|
||||
return downloaded
|
||||
except Exception as e:
|
||||
if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower():
|
||||
raise RuntimeError(
|
||||
f"[PrismAudio] Model '{filename}' requires license acceptance. "
|
||||
f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
|
||||
f"then set HF_TOKEN env var or run: huggingface-cli login"
|
||||
) from e
|
||||
raise
|
||||
|
||||
|
||||
class PrismAudioModelLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
register_model_folder()
|
||||
return {
|
||||
"required": {
|
||||
"precision": (["auto", "fp32", "fp16", "bf16"],),
|
||||
"offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("PRISMAUDIO_MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "load_model"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
|
||||
def load_model(self, precision, offload_strategy):
|
||||
device = get_device()
|
||||
dtype = determine_precision(precision, device)
|
||||
strategy = determine_offload_strategy(offload_strategy)
|
||||
token = resolve_hf_token()
|
||||
model_dir = get_prismaudio_model_dir()
|
||||
|
||||
# Auto-download missing files
|
||||
for key, filename in REQUIRED_FILES.items():
|
||||
_download_if_missing(filename, model_dir, hf_token=token)
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"prismaudio_core", "configs", "prismaudio.json"
|
||||
)
|
||||
with open(config_path) as f:
|
||||
model_config = json.load(f)
|
||||
|
||||
# Create model from config
|
||||
from prismaudio_core.factory import create_model_from_config
|
||||
model = create_model_from_config(model_config)
|
||||
|
||||
# Load diffusion weights
|
||||
diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"])
|
||||
diffusion_state = comfy.utils.load_torch_file(diffusion_path)
|
||||
# Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
|
||||
if "state_dict" in diffusion_state:
|
||||
diffusion_state = diffusion_state["state_dict"]
|
||||
model.load_state_dict(diffusion_state, strict=False)
|
||||
|
||||
# Load VAE weights separately
|
||||
# Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
|
||||
vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
|
||||
vae_full_state = comfy.utils.load_torch_file(vae_path)
|
||||
# Strip "autoencoder." prefix from keys
|
||||
vae_state = {}
|
||||
prefix = "autoencoder."
|
||||
for k, v in vae_full_state.items():
|
||||
if k.startswith(prefix):
|
||||
vae_state[k[len(prefix):]] = v
|
||||
else:
|
||||
vae_state[k] = v
|
||||
model.pretransform.load_state_dict(vae_state)
|
||||
|
||||
# Apply precision: DiT + conditioners in user-selected dtype,
|
||||
# but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
|
||||
model.model.to(dtype) # DiTWrapper
|
||||
model.conditioner.to(dtype) # MultiConditioner
|
||||
# model.pretransform stays in fp32
|
||||
|
||||
if strategy == "keep_in_vram":
|
||||
model = model.to(device)
|
||||
else:
|
||||
model = model.to(get_offload_device())
|
||||
|
||||
model.eval()
|
||||
|
||||
return ({
|
||||
"model": model,
|
||||
"dtype": dtype,
|
||||
"strategy": strategy,
|
||||
"config": model_config,
|
||||
"model_dir": model_dir,
|
||||
},)
|
||||
Reference in New Issue
Block a user