debug: log weight load stats for diffusion and VAE checkpoints

Print key counts, missing/unexpected keys, and sample key names to
diagnose whether weights are actually loading correctly (strict=False
silently hides mismatches that would cause garbage audio output).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 21:53:25 +01:00
parent 83a7f2787b
commit f2705b3063
+22 -2
View File
@@ -91,12 +91,25 @@ class PrismAudioModelLoader:
# Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
if "state_dict" in diffusion_state:
diffusion_state = diffusion_state["state_dict"]
model.load_state_dict(diffusion_state, strict=False)
diff_result = model.load_state_dict(diffusion_state, strict=False)
print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
print(f"[PrismAudio] Diffusion load: missing={len(diff_result.missing_keys)}, unexpected={len(diff_result.unexpected_keys)}", flush=True)
if diff_result.missing_keys:
print(f"[PrismAudio] missing (first 10): {diff_result.missing_keys[:10]}", flush=True)
if diff_result.unexpected_keys:
print(f"[PrismAudio] unexpected (first 5): {diff_result.unexpected_keys[:5]}", flush=True)
# Sample a few ckpt keys to verify prefix alignment
sample_keys = list(diffusion_state.keys())[:5]
print(f"[PrismAudio] ckpt key samples: {sample_keys}", flush=True)
# Load VAE weights separately
# Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
vae_full_state = comfy.utils.load_torch_file(vae_path)
print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True)
# Sample raw keys to see actual prefix
vae_sample_keys = list(vae_full_state.keys())[:8]
print(f"[PrismAudio] VAE raw key samples: {vae_sample_keys}", flush=True)
# Strip "autoencoder." prefix from keys
vae_state = {}
prefix = "autoencoder."
@@ -105,10 +118,17 @@ class PrismAudioModelLoader:
vae_state[k[len(prefix):]] = v
else:
vae_state[k] = v
print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)
# Sample model keys to compare
model_vae_keys = list(model.pretransform.state_dict().keys())[:5]
print(f"[PrismAudio] pretransform model key samples: {model_vae_keys}", flush=True)
# strict=False: vae.ckpt is a training checkpoint that also contains
# discriminator, loss modules, and EMA wrappers not present in the
# inference AudioAutoencoder — ignore those extra keys.
model.pretransform.load_state_dict(vae_state, strict=False)
vae_result = model.pretransform.load_state_dict(vae_state, strict=False)
print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True)
if vae_result.missing_keys:
print(f"[PrismAudio] VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True)
# Apply precision: DiT + conditioners in user-selected dtype,
# but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16