debug: log weight load stats for diffusion and VAE checkpoints
Print key counts, missing/unexpected keys, and sample key names to diagnose whether weights are actually loading correctly. This matters because strict=False silently hides key mismatches that would cause garbage audio output.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+22
-2
@@ -91,12 +91,25 @@ class PrismAudioModelLoader:
|
|||||||
# Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
|
# Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
|
||||||
if "state_dict" in diffusion_state:
|
if "state_dict" in diffusion_state:
|
||||||
diffusion_state = diffusion_state["state_dict"]
|
diffusion_state = diffusion_state["state_dict"]
|
||||||
model.load_state_dict(diffusion_state, strict=False)
|
diff_result = model.load_state_dict(diffusion_state, strict=False)
|
||||||
|
print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
|
||||||
|
print(f"[PrismAudio] Diffusion load: missing={len(diff_result.missing_keys)}, unexpected={len(diff_result.unexpected_keys)}", flush=True)
|
||||||
|
if diff_result.missing_keys:
|
||||||
|
print(f"[PrismAudio] missing (first 10): {diff_result.missing_keys[:10]}", flush=True)
|
||||||
|
if diff_result.unexpected_keys:
|
||||||
|
print(f"[PrismAudio] unexpected (first 5): {diff_result.unexpected_keys[:5]}", flush=True)
|
||||||
|
# Sample a few ckpt keys to verify prefix alignment
|
||||||
|
sample_keys = list(diffusion_state.keys())[:5]
|
||||||
|
print(f"[PrismAudio] ckpt key samples: {sample_keys}", flush=True)
|
||||||
|
|
||||||
# Load VAE weights separately
|
# Load VAE weights separately
|
||||||
# Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
|
# Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
|
||||||
vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
|
vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
|
||||||
vae_full_state = comfy.utils.load_torch_file(vae_path)
|
vae_full_state = comfy.utils.load_torch_file(vae_path)
|
||||||
|
print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True)
|
||||||
|
# Sample raw keys to see actual prefix
|
||||||
|
vae_sample_keys = list(vae_full_state.keys())[:8]
|
||||||
|
print(f"[PrismAudio] VAE raw key samples: {vae_sample_keys}", flush=True)
|
||||||
# Strip "autoencoder." prefix from keys
|
# Strip "autoencoder." prefix from keys
|
||||||
vae_state = {}
|
vae_state = {}
|
||||||
prefix = "autoencoder."
|
prefix = "autoencoder."
|
||||||
@@ -105,10 +118,17 @@ class PrismAudioModelLoader:
|
|||||||
vae_state[k[len(prefix):]] = v
|
vae_state[k[len(prefix):]] = v
|
||||||
else:
|
else:
|
||||||
vae_state[k] = v
|
vae_state[k] = v
|
||||||
|
print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)
|
||||||
|
# Sample model keys to compare
|
||||||
|
model_vae_keys = list(model.pretransform.state_dict().keys())[:5]
|
||||||
|
print(f"[PrismAudio] pretransform model key samples: {model_vae_keys}", flush=True)
|
||||||
# strict=False: vae.ckpt is a training checkpoint that also contains
|
# strict=False: vae.ckpt is a training checkpoint that also contains
|
||||||
# discriminator, loss modules, and EMA wrappers not present in the
|
# discriminator, loss modules, and EMA wrappers not present in the
|
||||||
# inference AudioAutoencoder — ignore those extra keys.
|
# inference AudioAutoencoder — ignore those extra keys.
|
||||||
model.pretransform.load_state_dict(vae_state, strict=False)
|
vae_result = model.pretransform.load_state_dict(vae_state, strict=False)
|
||||||
|
print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True)
|
||||||
|
if vae_result.missing_keys:
|
||||||
|
print(f"[PrismAudio] VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True)
|
||||||
|
|
||||||
# Apply precision: DiT + conditioners in user-selected dtype,
|
# Apply precision: DiT + conditioners in user-selected dtype,
|
||||||
# but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
|
# but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
|
||||||
|
|||||||
Reference in New Issue
Block a user