Print model load status to detect missing/unexpected weight keys
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
8
nodes.py
8
nodes.py
@@ -240,7 +240,13 @@ class STARModelLoader:
|
|||||||
load_dict = torch.load(model_path, map_location="cpu", weights_only=False)
|
load_dict = torch.load(model_path, map_location="cpu", weights_only=False)
|
||||||
if "state_dict" in load_dict:
|
if "state_dict" in load_dict:
|
||||||
load_dict = load_dict["state_dict"]
|
load_dict = load_dict["state_dict"]
|
||||||
generator.load_state_dict(load_dict, strict=False)
|
ret = generator.load_state_dict(load_dict, strict=False)
|
||||||
|
if ret.missing_keys:
|
||||||
|
print(f"[STAR] WARNING: {len(ret.missing_keys)} missing keys: {ret.missing_keys[:5]}...")
|
||||||
|
if ret.unexpected_keys:
|
||||||
|
print(f"[STAR] WARNING: {len(ret.unexpected_keys)} unexpected keys: {ret.unexpected_keys[:5]}...")
|
||||||
|
if not ret.missing_keys and not ret.unexpected_keys:
|
||||||
|
print("[STAR] Model loaded perfectly (all keys matched)")
|
||||||
del load_dict
|
del load_dict
|
||||||
generator = generator.to(device=keep_on, dtype=dtype)
|
generator = generator.to(device=keep_on, dtype=dtype)
|
||||||
generator.eval()
|
generator.eval()
|
||||||
|
|||||||
Reference in New Issue
Block a user