From ccc43f520e97bea5c1769d286c69d72bb6476a1a Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 12 Mar 2026 15:32:21 +0100 Subject: [PATCH] Fix save node meta tensor handling with multi-source weight loading Try pipeline["sd"] first (merged base+VACE weights), then diffusion_model state dict, then reload from checkpoint file. Applies LoRA patches from model patcher on top. Co-Authored-By: Claude Opus 4.6 --- save_node.py | 74 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 9 deletions(-) diff --git a/save_node.py b/save_node.py index f56d6ac..62fab74 100644 --- a/save_node.py +++ b/save_node.py @@ -3,9 +3,8 @@ import json import logging import torch import folder_paths -import comfy.model_management from safetensors.torch import save_file -from comfy.utils import ProgressBar +from comfy.utils import ProgressBar, load_torch_file log = logging.getLogger("ComfyUI-WanVideoSaveMerged") @@ -74,16 +73,73 @@ class WanVideoSaveMergedModel: metadata["merged_loras"] = json.dumps(lora_entries) metadata["save_dtype"] = save_dtype - # Extract state dict from the diffusion model (keys are already bare, - # e.g. "blocks.0.self_attn.k.weight" — matching original checkpoint format) + # Extract state dict from the diffusion model. + # WanVideo wrapper initializes models on meta device (shape-only, no data) + # and stores the real weights in pipeline["sd"] after merging base + VACE. + # We try multiple sources to get real (non-meta) weights: + # 1. pipeline["sd"] — merged state dict kept by the wrapper (includes VACE) + # 2. diffusion_model.state_dict() — works if model was loaded to real device + # 3. Reload from checkpoint file via base_path — fallback, base weights only diffusion_model = model.model.diffusion_model + pipeline = model.model.pipeline - # Force ComfyUI to load model weights into real memory. - # Without this, weights stay on meta device (shape-only, no data). - log.info("Loading model weights into memory for saving...") - comfy.model_management.load_models_gpu([model], force_full_load=True) + state_dict = None + + # Source 1: pipeline["sd"] — the merged (base + VACE + LoRA) state dict + pipeline_sd = pipeline.get("sd") + if pipeline_sd and isinstance(pipeline_sd, dict) and len(pipeline_sd) > 0: + has_meta = any( + hasattr(v, "device") and v.device.type == "meta" + for v in pipeline_sd.values() + if isinstance(v, torch.Tensor) + ) + if not has_meta: + log.info("Using merged state dict from pipeline (includes VACE weights)") + state_dict = pipeline_sd + + # Source 2: diffusion_model.state_dict() + if state_dict is None: + sd = diffusion_model.state_dict() + has_meta = any(v.device.type == "meta" for v in sd.values()) + if not has_meta: + log.info("Using state dict from diffusion model") + state_dict = sd + else: + del sd + + # Source 3: reload from checkpoint file on disk + if state_dict is None: + base_path = pipeline.get("base_path") or "" + if not base_path or not os.path.exists(base_path): + # Search ComfyUI model directories + name = str(model_name) + for folder_type in ("diffusion_models", "unet", "checkpoints"): + try: + base_path = folder_paths.get_full_path(folder_type, name) + except Exception: + base_path = None + if base_path and os.path.exists(base_path): + break + base_path = None + + if not base_path: + raise RuntimeError( + f"Model weights are on meta device and cannot find checkpoint file " + f"'{model_name}'. Ensure the model file is accessible." + ) + + log.info(f"Weights on meta device — loading from checkpoint: {base_path}") + log.warning("Loading from base checkpoint only — VACE weights may not be included. " + "For full merged save, ensure the model loader keeps pipeline['sd'].") + state_dict = load_torch_file(base_path, device="cpu") + + # Apply any LoRA patches from the model patcher + if hasattr(model, "patches") and model.patches: + log.info(f"Applying {len(model.patches)} LoRA patches...") + for key, patches in model.patches.items(): + if key in state_dict: + state_dict[key] = model.calculate_weight(patches, state_dict[key], key) - state_dict = diffusion_model.state_dict() target_dtype = dtype_map.get(save_dtype) pbar = ProgressBar(len(state_dict))