From 855d748330b93036786d74c22a0b38291bb3ce2c Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Tue, 10 Feb 2026 17:46:27 +0100
Subject: [PATCH] node

---
 __init__.py |   3 ++
 nodes.py    | 122 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 125 insertions(+)
 create mode 100644 __init__.py
 create mode 100644 nodes.py

diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..39a8c6b
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,3 @@
+from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+
+__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
diff --git a/nodes.py b/nodes.py
new file mode 100644
index 0000000..d0a93fe
--- /dev/null
+++ b/nodes.py
@@ -0,0 +1,122 @@
+import os
+import json
+import logging
+import torch
+import folder_paths
+from safetensors.torch import save_file
+from comfy.utils import ProgressBar
+
+log = logging.getLogger("ComfyUI-WanVideoSaveMerged")
+
+
+class WanVideoSaveMergedModel:
+    """ComfyUI output node: saves a WanVideo model, including any merged LoRAs, to a .safetensors file."""
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("WANVIDEOMODEL", {"tooltip": "WanVideo model with merged LoRA from the WanVideo Model Loader"}),
+                "filename_prefix": ("STRING", {"default": "merged_wanvideo", "tooltip": "Filename prefix for the saved model"}),
+            },
+            "optional": {
+                "save_dtype": (["same", "bf16", "fp16", "fp32"], {
+                    "default": "same",
+                    "tooltip": "Cast weights to this dtype before saving. 'same' keeps the current dtype of each tensor. Recommended to set explicitly if model was loaded in fp8."
+                }),
+                "custom_path": ("STRING", {
+                    "default": "",
+                    "tooltip": "Absolute path to save directory. Leave empty to save in ComfyUI/models/diffusion_models/"
+                }),
+            },
+        }
+
+    RETURN_TYPES = ()
+    FUNCTION = "save_model"
+    CATEGORY = "WanVideoWrapper"
+    OUTPUT_NODE = True
+    DESCRIPTION = "Saves the WanVideo diffusion model (including merged LoRAs) as a safetensors file"
+
+    def save_model(self, model, filename_prefix, save_dtype="same", custom_path=""):
+        """Save the wrapped diffusion model's state dict to a safetensors file.
+
+        Args:
+            model: WANVIDEOMODEL wrapper; weights are read from model.model.diffusion_model.
+            filename_prefix: Base name of the output file; a numeric suffix is appended to avoid overwriting.
+            save_dtype: "same" keeps each tensor's current dtype, otherwise cast to bf16/fp16/fp32.
+            custom_path: Absolute output directory; empty falls back to models/diffusion_models.
+        """
+        dtype_map = {
+            "bf16": torch.bfloat16,
+            "fp16": torch.float16,
+            "fp32": torch.float32,
+        }
+
+        # Build output directory
+        if custom_path and os.path.isabs(custom_path):
+            output_dir = custom_path
+        else:
+            output_dir = os.path.join(folder_paths.models_dir, "diffusion_models")
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Build filename, avoid overwriting
+        filename = f"{filename_prefix}.safetensors"
+        output_path = os.path.join(output_dir, filename)
+        counter = 1
+        while os.path.exists(output_path):
+            filename = f"{filename_prefix}_{counter}.safetensors"
+            output_path = os.path.join(output_dir, filename)
+            counter += 1
+
+        # Gather metadata about the merge for traceability
+        metadata = {}
+        model_name = model.model.pipeline.get("model_name", "unknown")
+        metadata["source_model"] = str(model_name)
+        lora_info = model.model.pipeline.get("lora")
+        if lora_info is not None:
+            lora_entries = []
+            for lora in lora_info:
+                lora_entries.append({
+                    "name": lora.get("name", "unknown"),
+                    "strength": lora.get("strength", 1.0),
+                })
+            metadata["merged_loras"] = json.dumps(lora_entries)
+        metadata["save_dtype"] = save_dtype
+
+        # Extract state dict from the diffusion model (keys are already bare,
+        # e.g. "blocks.0.self_attn.k.weight" — matching original checkpoint format)
+        diffusion_model = model.model.diffusion_model
+        state_dict = diffusion_model.state_dict()
+
+        target_dtype = dtype_map.get(save_dtype)
+        pbar = ProgressBar(len(state_dict))
+
+        clean_sd = {}
+        for k, v in state_dict.items():
+            tensor = v.cpu()
+            if target_dtype is not None:
+                tensor = tensor.to(target_dtype)
+            clean_sd[k] = tensor
+            pbar.update(1)
+
+        log.info(f"Saving merged WanVideo model to: {output_path}")
+        log.info(f"Number of tensors: {len(clean_sd)}")
+
+        save_file(clean_sd, output_path, metadata=metadata)
+
+        log.info(f"Model saved successfully: {output_path}")
+        del clean_sd
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return ()
+
+
+NODE_CLASS_MAPPINGS = {
+    "WanVideoSaveMergedModel": WanVideoSaveMergedModel,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "WanVideoSaveMergedModel": "WanVideo Save Merged Model",
+}