11457fc27a
AutoencoderPretransform.load_state_dict() doesn't return IncompatibleKeys. Load into pretransform.model (AudioAutoencoder) to get the return value and see actual missing/unexpected key counts. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
155 lines
6.4 KiB
Python
155 lines
6.4 KiB
Python
import os
|
|
import json
|
|
import torch
|
|
import folder_paths
|
|
import comfy.model_management as mm
|
|
import comfy.utils
|
|
|
|
from .utils import (
|
|
PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
|
|
get_device, get_offload_device, determine_precision, determine_offload_strategy,
|
|
soft_empty_cache, resolve_hf_token,
|
|
)
|
|
|
|
# Hugging Face Hub repository that missing checkpoint files are fetched from.
HF_REPO_ID = "FunAudioLLM/PrismAudio"

# Checkpoint file names keyed by the component they belong to. The same name
# is used both inside the Hub repository and inside the local model
# directory. Insertion order is the order in which files are checked and
# downloaded.
REQUIRED_FILES = dict(
    diffusion="prismaudio.ckpt",
    vae="vae.ckpt",
    synchformer="synchformer_state_dict.pth",
)
|
|
|
|
|
|
def _download_if_missing(filename, model_dir, hf_token=None):
    """Return the local path of a model file, downloading it if it is absent.

    Args:
        filename: File name inside the ``HF_REPO_ID`` repository; the same
            name is looked up under ``model_dir``.
        model_dir: Local directory that is checked first and that is passed
            to ``hf_hub_download`` as ``local_dir``.
        hf_token: Optional HuggingFace access token. Falsy values (``None``,
            empty string) are forwarded as ``None``.

    Returns:
        ``os.path.join(model_dir, filename)`` when that file already exists,
        otherwise the value returned by ``hf_hub_download``.

    Raises:
        RuntimeError: When the download fails and the error text contains
            ``401``, ``403`` or ``gated``; the original error is chained as
            the cause.
        Exception: Any other download failure is re-raised unchanged.
    """
    filepath = os.path.join(model_dir, filename)
    # isfile, not exists: a directory that happens to carry the checkpoint's
    # name is not a usable model file and must not short-circuit the download.
    if os.path.isfile(filepath):
        return filepath

    # Imported after the early return, so huggingface_hub is only loaded when
    # a download is actually required.
    from huggingface_hub import hf_hub_download

    print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
    try:
        return hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=model_dir,
            token=hf_token or None,
        )
    except Exception as e:
        # Auth / licence problems are recognised from the error text and
        # turned into an actionable message; everything else propagates as-is.
        message = str(e)
        if "401" in message or "403" in message or "gated" in message.lower():
            raise RuntimeError(
                f"[PrismAudio] Model '{filename}' requires license acceptance. "
                f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
                "then set HF_TOKEN env var or run: huggingface-cli login"
            ) from e
        raise
|
|
|
|
|
|
class PrismAudioModelLoader:
    """Node that builds the PrismAudio model and loads its checkpoint weights.

    The class exposes the attributes a node host reads (``INPUT_TYPES``,
    ``RETURN_TYPES``, ``RETURN_NAMES``, ``FUNCTION``, ``CATEGORY``) and a
    single entry point, ``load_model``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification for this node.

        Calls ``register_model_folder()`` as a side effect before building
        the specification.

        Returns:
            A dict with a ``"required"`` section offering the ``precision``
            and ``offload_strategy`` choice lists.
        """
        register_model_folder()
        return {
            "required": {
                "precision": (["auto", "fp32", "fp16", "bf16"],),
                "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
            },
        }

    # Single output: the model bundle dict produced by load_model.
    RETURN_TYPES = ("PRISMAUDIO_MODEL",)
    RETURN_NAMES = ("model",)
    # Name of the method used as the node's entry point.
    FUNCTION = "load_model"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_model(self, precision, offload_strategy):
        """Download any missing files, build the model and load its weights.

        All files in ``REQUIRED_FILES`` are ensured to exist locally, but only
        the diffusion and VAE checkpoints are loaded here; the synchformer
        file is downloaded and not otherwise used by this method.

        Args:
            precision: One of the ``precision`` choices from ``INPUT_TYPES``;
                resolved to a dtype by ``determine_precision``.
            offload_strategy: One of the ``offload_strategy`` choices from
                ``INPUT_TYPES``; resolved by ``determine_offload_strategy``.

        Returns:
            A 1-tuple holding a dict with the keys ``"model"``, ``"dtype"``,
            ``"strategy"``, ``"config"`` and ``"model_dir"``.

        Raises:
            RuntimeError: Propagated from ``_download_if_missing`` when a
                download is rejected for auth / licence reasons.
        """
        device = get_device()
        dtype = determine_precision(precision, device)
        strategy = determine_offload_strategy(offload_strategy)
        token = resolve_hf_token()
        model_dir = get_prismaudio_model_dir()

        # Auto-download missing files (only the file names are needed here).
        for filename in REQUIRED_FILES.values():
            _download_if_missing(filename, model_dir, hf_token=token)

        # Load config. JSON is read as UTF-8 explicitly so the result does not
        # depend on the platform's default locale encoding.
        config_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "prismaudio_core", "configs", "prismaudio.json",
        )
        with open(config_path, encoding="utf-8") as f:
            model_config = json.load(f)

        # Create model from config
        from prismaudio_core.factory import create_model_from_config
        model = create_model_from_config(model_config)

        # Load diffusion weights
        diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"])
        diffusion_state = comfy.utils.load_torch_file(diffusion_path)
        # Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
        if "state_dict" in diffusion_state:
            diffusion_state = diffusion_state["state_dict"]
        diff_result = model.load_state_dict(diffusion_state, strict=False)
        print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
        print(f"[PrismAudio] Diffusion load: missing={len(diff_result.missing_keys)}, unexpected={len(diff_result.unexpected_keys)}", flush=True)
        if diff_result.missing_keys:
            print(f"[PrismAudio] missing (first 10): {diff_result.missing_keys[:10]}", flush=True)
        if diff_result.unexpected_keys:
            print(f"[PrismAudio] unexpected (first 5): {diff_result.unexpected_keys[:5]}", flush=True)
        # Sample a few ckpt keys to verify prefix alignment
        sample_keys = list(diffusion_state.keys())[:5]
        print(f"[PrismAudio] ckpt key samples: {sample_keys}", flush=True)
        # The weights now live in the model; drop the checkpoint copy before
        # the next checkpoint is read to keep peak host memory down.
        del diffusion_state

        # Load VAE weights separately
        # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
        vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
        vae_full_state = comfy.utils.load_torch_file(vae_path)
        # Same wrapper handling as for the diffusion checkpoint above.
        if "state_dict" in vae_full_state:
            vae_full_state = vae_full_state["state_dict"]
        print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True)
        # Sample raw keys to see actual prefix
        vae_sample_keys = list(vae_full_state.keys())[:8]
        print(f"[PrismAudio] VAE raw key samples: {vae_sample_keys}", flush=True)
        # Strip "autoencoder." prefix from keys; other keys pass through as-is.
        prefix = "autoencoder."
        vae_state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in vae_full_state.items()
        }
        print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)
        # Sample model keys to compare
        model_vae_keys = list(model.pretransform.state_dict().keys())[:5]
        print(f"[PrismAudio] pretransform model key samples: {model_vae_keys}", flush=True)
        # strict=False: vae.ckpt is a training checkpoint that also contains
        # discriminator, loss modules, and EMA wrappers not present in the
        # inference AudioAutoencoder — ignore those extra keys.
        # Load directly into the inner AudioAutoencoder to get IncompatibleKeys back
        # (AutoencoderPretransform.load_state_dict doesn't return the result)
        vae_result = model.pretransform.model.load_state_dict(vae_state, strict=False)
        print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True)
        if vae_result.missing_keys:
            print(f"[PrismAudio] VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True)
        # Both dicts reference the same checkpoint tensors; release them together.
        del vae_full_state, vae_state

        # Apply precision: DiT + conditioners in user-selected dtype,
        # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
        model.model.to(dtype)  # DiTWrapper
        model.conditioner.to(dtype)  # MultiConditioner
        # model.pretransform stays in fp32

        if strategy == "keep_in_vram":
            model = model.to(device)
        else:
            model = model.to(get_offload_device())

        model.eval()

        return ({
            "model": model,
            "dtype": dtype,
            "strategy": strategy,
            "config": model_config,
            "model_dir": model_dir,
        },)
|