diff --git a/nodes.py b/nodes.py index 8e44778..3740cf3 100644 --- a/nodes.py +++ b/nodes.py @@ -240,7 +240,13 @@ class STARModelLoader: load_dict = torch.load(model_path, map_location="cpu", weights_only=False) if "state_dict" in load_dict: load_dict = load_dict["state_dict"] - generator.load_state_dict(load_dict, strict=False) + ret = generator.load_state_dict(load_dict, strict=False) + if ret.missing_keys: + print(f"[STAR] WARNING: {len(ret.missing_keys)} missing keys: {ret.missing_keys[:5]}...") + if ret.unexpected_keys: + print(f"[STAR] WARNING: {len(ret.unexpected_keys)} unexpected keys: {ret.unexpected_keys[:5]}...") + if not ret.missing_keys and not ret.unexpected_keys: + print("[STAR] Model loaded perfectly (all keys matched)") del load_dict generator = generator.to(device=keep_on, dtype=dtype) generator.eval()