Print model load status to detect missing/unexpected weight keys

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-15 02:34:26 +01:00
parent 162c27d243
commit 45e57f58a0

View File

@@ -240,7 +240,13 @@ class STARModelLoader:
load_dict = torch.load(model_path, map_location="cpu", weights_only=False) load_dict = torch.load(model_path, map_location="cpu", weights_only=False)
if "state_dict" in load_dict: if "state_dict" in load_dict:
load_dict = load_dict["state_dict"] load_dict = load_dict["state_dict"]
generator.load_state_dict(load_dict, strict=False) ret = generator.load_state_dict(load_dict, strict=False)
if ret.missing_keys:
print(f"[STAR] WARNING: {len(ret.missing_keys)} missing keys: {ret.missing_keys[:5]}...")
if ret.unexpected_keys:
print(f"[STAR] WARNING: {len(ret.unexpected_keys)} unexpected keys: {ret.unexpected_keys[:5]}...")
if not ret.missing_keys and not ret.unexpected_keys:
print("[STAR] Model loaded perfectly (all keys matched)")
del load_dict del load_dict
generator = generator.to(device=keep_on, dtype=dtype) generator = generator.to(device=keep_on, dtype=dtype)
generator.eval() generator.eval()