Fix save node meta tensor handling with multi-source weight loading
Try pipeline["sd"] first (merged base+VACE weights), then diffusion_model state dict, then reload from checkpoint file. Applies LoRA patches from model patcher on top. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+65
-9
@@ -3,9 +3,8 @@ import json
|
||||
import logging
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy.model_management
|
||||
from safetensors.torch import save_file
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy.utils import ProgressBar, load_torch_file
|
||||
|
||||
log = logging.getLogger("ComfyUI-WanVideoSaveMerged")
|
||||
|
||||
@@ -74,16 +73,73 @@ class WanVideoSaveMergedModel:
|
||||
metadata["merged_loras"] = json.dumps(lora_entries)
|
||||
metadata["save_dtype"] = save_dtype
|
||||
|
||||
# Extract state dict from the diffusion model (keys are already bare,
|
||||
# e.g. "blocks.0.self_attn.k.weight" — matching original checkpoint format)
|
||||
# Extract state dict from the diffusion model.
|
||||
# WanVideo wrapper initializes models on meta device (shape-only, no data)
|
||||
# and stores the real weights in pipeline["sd"] after merging base + VACE.
|
||||
# We try multiple sources to get real (non-meta) weights:
|
||||
# 1. pipeline["sd"] — merged state dict kept by the wrapper (includes VACE)
|
||||
# 2. diffusion_model.state_dict() — works if model was loaded to real device
|
||||
# 3. Reload from checkpoint file via base_path — fallback, base weights only
|
||||
diffusion_model = model.model.diffusion_model
|
||||
pipeline = model.model.pipeline
|
||||
|
||||
# Force ComfyUI to load model weights into real memory.
|
||||
# Without this, weights stay on meta device (shape-only, no data).
|
||||
log.info("Loading model weights into memory for saving...")
|
||||
comfy.model_management.load_models_gpu([model], force_full_load=True)
|
||||
state_dict = None
|
||||
|
||||
# Source 1: pipeline["sd"] — the merged (base + VACE + LoRA) state dict
|
||||
pipeline_sd = pipeline.get("sd")
|
||||
if pipeline_sd and isinstance(pipeline_sd, dict) and len(pipeline_sd) > 0:
|
||||
has_meta = any(
|
||||
hasattr(v, "device") and v.device.type == "meta"
|
||||
for v in pipeline_sd.values()
|
||||
if isinstance(v, torch.Tensor)
|
||||
)
|
||||
if not has_meta:
|
||||
log.info("Using merged state dict from pipeline (includes VACE weights)")
|
||||
state_dict = pipeline_sd
|
||||
|
||||
# Source 2: diffusion_model.state_dict()
|
||||
if state_dict is None:
|
||||
sd = diffusion_model.state_dict()
|
||||
has_meta = any(v.device.type == "meta" for v in sd.values())
|
||||
if not has_meta:
|
||||
log.info("Using state dict from diffusion model")
|
||||
state_dict = sd
|
||||
else:
|
||||
del sd
|
||||
|
||||
# Source 3: reload from checkpoint file on disk
|
||||
if state_dict is None:
|
||||
base_path = pipeline.get("base_path") or ""
|
||||
if not base_path or not os.path.exists(base_path):
|
||||
# Search ComfyUI model directories
|
||||
name = str(model_name)
|
||||
for folder_type in ("diffusion_models", "unet", "checkpoints"):
|
||||
try:
|
||||
base_path = folder_paths.get_full_path(folder_type, name)
|
||||
except Exception:
|
||||
base_path = None
|
||||
if base_path and os.path.exists(base_path):
|
||||
break
|
||||
base_path = None
|
||||
|
||||
if not base_path:
|
||||
raise RuntimeError(
|
||||
f"Model weights are on meta device and cannot find checkpoint file "
|
||||
f"'{model_name}'. Ensure the model file is accessible."
|
||||
)
|
||||
|
||||
log.info(f"Weights on meta device — loading from checkpoint: {base_path}")
|
||||
log.warning("Loading from base checkpoint only — VACE weights may not be included. "
|
||||
"For full merged save, ensure the model loader keeps pipeline['sd'].")
|
||||
state_dict = load_torch_file(base_path, device="cpu")
|
||||
|
||||
# Apply any LoRA patches from the model patcher
|
||||
if hasattr(model, "patches") and model.patches:
|
||||
log.info(f"Applying {len(model.patches)} LoRA patches...")
|
||||
for key, patches in model.patches.items():
|
||||
if key in state_dict:
|
||||
state_dict[key] = model.calculate_weight(patches, state_dict[key], key)
|
||||
|
||||
state_dict = diffusion_model.state_dict()
|
||||
|
||||
target_dtype = dtype_map.get(save_dtype)
|
||||
pbar = ProgressBar(len(state_dict))
|
||||
|
||||
Reference in New Issue
Block a user