import os
import json

import torch

import folder_paths
import comfy.model_management as mm
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY,
    get_prismaudio_model_dir,
    register_model_folder,
    get_device,
    get_offload_device,
    determine_precision,
    determine_offload_strategy,
    soft_empty_cache,
    resolve_hf_token,
)

# HuggingFace repo for auto-download
HF_REPO_ID = "FunAudioLLM/PrismAudio"

# Logical component name -> checkpoint filename expected inside the model dir.
REQUIRED_FILES = {
    "diffusion": "prismaudio.ckpt",
    "vae": "vae.ckpt",
    "synchformer": "synchformer_state_dict.pth",
}

# Key prefix removed from VAE checkpoint entries before they are loaded into
# ``model.pretransform``.
_VAE_KEY_PREFIX = "autoencoder."


def _download_if_missing(filename, model_dir, hf_token=None):
    """Download a model file from HuggingFace if not present locally.

    Args:
        filename: Name of the file inside ``HF_REPO_ID`` (and inside
            ``model_dir`` once downloaded).
        model_dir: Local directory that is checked first and used as the
            download target.
        hf_token: Optional HuggingFace access token; falsy values are passed
            on as ``None``.

    Returns:
        The existing local path when the file is already present, otherwise
        the path returned by ``hf_hub_download``.

    Raises:
        RuntimeError: If the download error text contains "401", "403" or
            "gated" — treated as a licence/authentication problem.
        Exception: Any other download error is re-raised unchanged.
    """
    filepath = os.path.join(model_dir, filename)
    if os.path.exists(filepath):
        return filepath

    # Imported lazily: only reached when a file actually has to be fetched.
    from huggingface_hub import hf_hub_download

    print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
    try:
        return hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=model_dir,
            token=hf_token or None,
        )
    except Exception as e:
        message = str(e)
        if "401" in message or "403" in message or "gated" in message.lower():
            raise RuntimeError(
                f"[PrismAudio] Model '{filename}' requires license acceptance. "
                f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
                f"then set HF_TOKEN env var or run: huggingface-cli login"
            ) from e
        raise


def _load_model_config():
    """Read ``prismaudio_core/configs/prismaudio.json`` and return it parsed.

    The path is resolved relative to the parent of this file's directory.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)),
        "prismaudio_core",
        "configs",
        "prismaudio.json",
    )
    # Explicit UTF-8 so parsing does not depend on the platform's default
    # text encoding.
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)


def _load_diffusion_weights(model, model_dir):
    """Load the diffusion checkpoint into ``model`` and print a load report."""
    diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"])
    diffusion_state = comfy.utils.load_torch_file(diffusion_path)

    # Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
    if "state_dict" in diffusion_state:
        diffusion_state = diffusion_state["state_dict"]

    # strict=False: mismatched keys are reported below instead of raising.
    diff_result = model.load_state_dict(diffusion_state, strict=False)
    print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
    print(f"[PrismAudio] Diffusion load: missing={len(diff_result.missing_keys)}, unexpected={len(diff_result.unexpected_keys)}", flush=True)
    if diff_result.missing_keys:
        print(f"[PrismAudio] missing (first 10): {diff_result.missing_keys[:10]}", flush=True)
    if diff_result.unexpected_keys:
        print(f"[PrismAudio] unexpected (first 5): {diff_result.unexpected_keys[:5]}", flush=True)

    # Sample a few ckpt keys to verify prefix alignment
    sample_keys = list(diffusion_state.keys())[:5]
    print(f"[PrismAudio] ckpt key samples: {sample_keys}", flush=True)


def _load_vae_weights(model, model_dir):
    """Load the VAE checkpoint into ``model.pretransform`` and print a report."""
    # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
    vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
    vae_full_state = comfy.utils.load_torch_file(vae_path)
    print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True)

    # Sample raw keys to see actual prefix
    vae_sample_keys = list(vae_full_state.keys())[:8]
    print(f"[PrismAudio] VAE raw key samples: {vae_sample_keys}", flush=True)

    # Strip "autoencoder." prefix from keys; other keys pass through as-is.
    prefix_len = len(_VAE_KEY_PREFIX)
    vae_state = {
        (k[prefix_len:] if k.startswith(_VAE_KEY_PREFIX) else k): v
        for k, v in vae_full_state.items()
    }
    print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)

    # Sample model keys to compare
    model_vae_keys = list(model.pretransform.state_dict().keys())[:5]
    print(f"[PrismAudio] pretransform model key samples: {model_vae_keys}", flush=True)

    # strict=False: vae.ckpt is a training checkpoint that also contains
    # discriminator, loss modules, and EMA wrappers not present in the
    # inference AudioAutoencoder — ignore those extra keys.
    vae_result = model.pretransform.load_state_dict(vae_state, strict=False)
    print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True)
    if vae_result.missing_keys:
        print(f"[PrismAudio] VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True)


class PrismAudioModelLoader:
    """ComfyUI node that builds the PrismAudio model and loads its weights."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input spec; also calls ``register_model_folder()``."""
        register_model_folder()
        return {
            "required": {
                "precision": (["auto", "fp32", "fp16", "bf16"],),
                "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_model(self, precision, offload_strategy):
        """Download missing files, build the model, and load its weights.

        Args:
            precision: One of the ``precision`` choices from ``INPUT_TYPES``;
                resolved to a dtype via ``determine_precision``.
            offload_strategy: One of the ``offload_strategy`` choices from
                ``INPUT_TYPES``; resolved via ``determine_offload_strategy``.

        Returns:
            A one-element tuple holding a dict with the keys ``"model"``,
            ``"dtype"``, ``"strategy"``, ``"config"`` and ``"model_dir"``.
        """
        device = get_device()
        dtype = determine_precision(precision, device)
        strategy = determine_offload_strategy(offload_strategy)
        token = resolve_hf_token()
        model_dir = get_prismaudio_model_dir()

        # Auto-download missing files
        for filename in REQUIRED_FILES.values():
            _download_if_missing(filename, model_dir, hf_token=token)

        # Load config
        model_config = _load_model_config()

        # Create model from config
        from prismaudio_core.factory import create_model_from_config
        model = create_model_from_config(model_config)

        # Each helper keeps its checkpoint state dict local, so it goes out of
        # scope before the next load and before the device move below.
        _load_diffusion_weights(model, model_dir)
        _load_vae_weights(model, model_dir)

        # Apply precision: DiT + conditioners in user-selected dtype,
        # but keep VAE (pretransform) in fp32 to avoid NaN from snake
        # activations in fp16
        model.model.to(dtype)        # DiTWrapper
        model.conditioner.to(dtype)  # MultiConditioner
        # model.pretransform stays in fp32

        if strategy == "keep_in_vram":
            model = model.to(device)
        else:
            model = model.to(get_offload_device())

        model.eval()

        return ({
            "model": model,
            "dtype": dtype,
            "strategy": strategy,
            "config": model_config,
            "model_dir": model_dir,
        },)