Print model load status to detect missing/unexpected weight keys
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
8
nodes.py
8
nodes.py
@@ -240,7 +240,13 @@ class STARModelLoader:
|
||||
load_dict = torch.load(model_path, map_location="cpu", weights_only=False)
|
||||
if "state_dict" in load_dict:
|
||||
load_dict = load_dict["state_dict"]
|
||||
generator.load_state_dict(load_dict, strict=False)
|
||||
ret = generator.load_state_dict(load_dict, strict=False)
|
||||
if ret.missing_keys:
|
||||
print(f"[STAR] WARNING: {len(ret.missing_keys)} missing keys: {ret.missing_keys[:5]}...")
|
||||
if ret.unexpected_keys:
|
||||
print(f"[STAR] WARNING: {len(ret.unexpected_keys)} unexpected keys: {ret.unexpected_keys[:5]}...")
|
||||
if not ret.missing_keys and not ret.unexpected_keys:
|
||||
print("[STAR] Model loaded perfectly (all keys matched)")
|
||||
del load_dict
|
||||
generator = generator.to(device=keep_on, dtype=dtype)
|
||||
generator.eval()
|
||||
|
||||
Reference in New Issue
Block a user