diff --git a/nodes/model_loader.py b/nodes/model_loader.py index e0ef4c1..8daab2d 100644 --- a/nodes/model_loader.py +++ b/nodes/model_loader.py @@ -91,12 +91,25 @@ class PrismAudioModelLoader: # Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...} if "state_dict" in diffusion_state: diffusion_state = diffusion_state["state_dict"] - model.load_state_dict(diffusion_state, strict=False) + diff_result = model.load_state_dict(diffusion_state, strict=False) + print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True) + print(f"[PrismAudio] Diffusion load: missing={len(diff_result.missing_keys)}, unexpected={len(diff_result.unexpected_keys)}", flush=True) + if diff_result.missing_keys: + print(f"[PrismAudio] missing (first 10): {diff_result.missing_keys[:10]}", flush=True) + if diff_result.unexpected_keys: + print(f"[PrismAudio] unexpected (first 5): {diff_result.unexpected_keys[:5]}", flush=True) + # Sample a few ckpt keys to verify prefix alignment + sample_keys = list(diffusion_state.keys())[:5] + print(f"[PrismAudio] ckpt key samples: {sample_keys}", flush=True) # Load VAE weights separately # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"]) vae_full_state = comfy.utils.load_torch_file(vae_path) + print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True) + # Sample raw keys to see actual prefix + vae_sample_keys = list(vae_full_state.keys())[:8] + print(f"[PrismAudio] VAE raw key samples: {vae_sample_keys}", flush=True) # Strip "autoencoder." prefix from keys vae_state = {} prefix = "autoencoder." @@ -105,10 +118,17 @@ class PrismAudioModelLoader: vae_state[k[len(prefix):]] = v else: vae_state[k] = v + print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True) + # Sample model keys to compare + model_vae_keys = list(model.pretransform.state_dict().keys())[:5] + print(f"[PrismAudio] pretransform model key samples: {model_vae_keys}", flush=True) # strict=False: vae.ckpt is a training checkpoint that also contains # discriminator, loss modules, and EMA wrappers not present in the # inference AudioAutoencoder — ignore those extra keys. - model.pretransform.load_state_dict(vae_state, strict=False) + vae_result = model.pretransform.load_state_dict(vae_state, strict=False) + print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True) + if vae_result.missing_keys: + print(f"[PrismAudio] VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True) # Apply precision: DiT + conditioners in user-selected dtype, # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16