fix: load VAE state dict with strict=False

vae.ckpt is a full training checkpoint containing discriminator, STFT
loss modules, and EMA wrappers that are absent from the inference
AudioAutoencoder. strict=False ignores these training-only keys while
still loading all encoder/decoder/bottleneck weights correctly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-27 19:51:51 +01:00
parent afc7d5b657
commit 8e3ab999f0
+4 -1
@@ -105,7 +105,10 @@ class PrismAudioModelLoader:
                 vae_state[k[len(prefix):]] = v
             else:
                 vae_state[k] = v
-        model.pretransform.load_state_dict(vae_state)
+        # strict=False: vae.ckpt is a training checkpoint that also contains
+        # discriminator, loss modules, and EMA wrappers not present in the
+        # inference AudioAutoencoder — ignore those extra keys.
+        model.pretransform.load_state_dict(vae_state, strict=False)
         # Apply precision: DiT + conditioners in user-selected dtype,
         # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
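
One caveat with this change: in PyTorch, strict=False tolerates missing keys as well as unexpected ones, so a checkpoint that lacked an encoder or decoder weight would also load without an error. The sketch below is not part of this commit. It is a minimal, framework-free illustration of how the checkpoint keys could be split into kept, unexpected, and missing sets before loading. The helper name partition_state_dict and the example key names are hypothetical. In the real loader, the expected keys would presumably come from model.pretransform.state_dict().

```python
from typing import Any, Dict, Iterable, List, Mapping, Tuple


def partition_state_dict(
    checkpoint: Mapping[str, Any],
    expected_keys: Iterable[str],
) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """Split a checkpoint against the keys the inference model expects.

    Hypothetical helper, not from the repository. Returns:
      kept       - entries whose key the model expects
      unexpected - checkpoint keys the model does not have (training-only)
      missing    - expected keys that the checkpoint does not provide
    """
    expected = set(expected_keys)
    kept = {k: v for k, v in checkpoint.items() if k in expected}
    unexpected = sorted(k for k in checkpoint if k not in expected)
    missing = sorted(k for k in expected if k not in checkpoint)
    return kept, unexpected, missing


# Toy data with made-up key names; real keys depend on the checkpoint.
toy_checkpoint = {
    "encoder.weight": 1.0,
    "decoder.weight": 2.0,
    "discriminator.weight": 3.0,  # training-only, absent from inference model
}
toy_expected = ["encoder.weight", "decoder.weight", "bottleneck.scale"]

kept, unexpected, missing = partition_state_dict(toy_checkpoint, toy_expected)

# Training-only extras are fine to drop, but a non-empty `missing` list
# means some inference weight would be left at its initial value.
if missing:
    print("checkpoint is missing expected keys:", missing)
```

With PyTorch itself, the same information is available from the value returned by load_state_dict, which exposes missing_keys and unexpected_keys. Asserting that missing_keys is empty after the strict=False call would keep the tolerance for training-only keys while still failing loudly on an incomplete checkpoint.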