fix: load VAE state dict with strict=False
vae.ckpt is a full training checkpoint containing discriminator, STFT loss modules, and EMA wrappers that are absent from the inference AudioAutoencoder. strict=False ignores these training-only keys while still loading all encoder/decoder/bottleneck weights correctly. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -105,7 +105,10 @@ class PrismAudioModelLoader:
|
||||
vae_state[k[len(prefix):]] = v
|
||||
else:
|
||||
vae_state[k] = v
|
||||
model.pretransform.load_state_dict(vae_state)
|
||||
# strict=False: vae.ckpt is a training checkpoint that also contains
|
||||
# discriminator, loss modules, and EMA wrappers not present in the
|
||||
# inference AudioAutoencoder — ignore those extra keys.
|
||||
model.pretransform.load_state_dict(vae_state, strict=False)
|
||||
|
||||
# Apply precision: DiT + conditioners in user-selected dtype,
|
||||
# but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
|
||||
|
||||
Reference in New Issue
Block a user