diff --git a/nodes/model_loader.py b/nodes/model_loader.py index 5d92977..e0ef4c1 100644 --- a/nodes/model_loader.py +++ b/nodes/model_loader.py @@ -105,7 +105,10 @@ class PrismAudioModelLoader: vae_state[k[len(prefix):]] = v else: vae_state[k] = v - model.pretransform.load_state_dict(vae_state) + # strict=False: vae.ckpt is a training checkpoint that also contains + # discriminator, loss modules, and EMA wrappers not present in the + # inference AudioAutoencoder — ignore those extra keys. + model.pretransform.load_state_dict(vae_state, strict=False) # Apply precision: DiT + conditioners in user-selected dtype, # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16