diff --git a/README.md b/README.md index 0240f8b..4c57ce4 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # ComfyUI-VACE-Tools -A single ComfyUI node that replaces ~149 manually wired nodes for generating VACE mask and control-frame sequences. +ComfyUI custom nodes for WanVideo/VACE workflows — mask/control-frame generation and model saving. ## Installation @@ -9,7 +9,7 @@ cd ComfyUI/custom_nodes/ git clone https://github.com/ethanfel/Comfyui-VACE-Tools.git ``` -Restart ComfyUI. The node appears under the **VACE Tools** category. +Restart ComfyUI. Nodes appear under the **VACE Tools** and **WanVideoWrapper** categories. ## Node: VACE Mask Generator @@ -260,6 +260,28 @@ control_frames: [ k0][ GREY ][ k1][ GREY ][ k2][ GREY ][ k3] | `segment_1` | Full source clip (keyframe images) | | `segment_2`–`4` | Placeholder | +--- + +## Node: WanVideo Save Merged Model + +Saves a WanVideo diffusion model (with merged LoRAs) as a `.safetensors` file. Found under the **WanVideoWrapper** category. + +### Inputs + +| Input | Type | Default | Description | +|---|---|---|---| +| `model` | WANVIDEOMODEL | — | WanVideo model with merged LoRA from the WanVideo Model Loader. | +| `filename_prefix` | STRING | `merged_wanvideo` | Filename prefix for the saved file. A numeric suffix is appended to avoid overwriting. | +| `save_dtype` | ENUM | `same` | Cast weights before saving: `same`, `bf16`, `fp16`, or `fp32`. Set explicitly if the model was loaded in fp8. | +| `custom_path` | STRING | *(optional)* | Absolute path to save directory. Leave empty to save in `ComfyUI/models/diffusion_models/`. | + +### Behavior + +- Extracts the diffusion model state dict and saves it in safetensors format. +- Records source model name and merged LoRA details (names + strengths) in file metadata for traceability. +- Clones all tensors before saving to handle shared/aliased weights safely. +- Automatically avoids overwriting existing files by appending `_1`, `_2`, etc. + ## Dependencies -None beyond PyTorch, which is bundled with ComfyUI. +PyTorch and safetensors, both bundled with ComfyUI. diff --git a/__init__.py b/__init__.py index 39a8c6b..1d1ef9e 100644 --- a/__init__.py +++ b/__init__.py @@ -1,3 +1,10 @@ from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS +from .save_node import ( + NODE_CLASS_MAPPINGS as SAVE_CLASS_MAPPINGS, + NODE_DISPLAY_NAME_MAPPINGS as SAVE_DISPLAY_MAPPINGS, +) + +NODE_CLASS_MAPPINGS.update(SAVE_CLASS_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(SAVE_DISPLAY_MAPPINGS) __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/save_node.py b/save_node.py new file mode 100644 index 0000000..5902bc2 --- /dev/null +++ b/save_node.py @@ -0,0 +1,116 @@ +import os +import json +import logging +import torch +import folder_paths +from safetensors.torch import save_file +from comfy.utils import ProgressBar + +log = logging.getLogger("ComfyUI-WanVideoSaveMerged") + + +class WanVideoSaveMergedModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("WANVIDEOMODEL", {"tooltip": "WanVideo model with merged LoRA from the WanVideo Model Loader"}), + "filename_prefix": ("STRING", {"default": "merged_wanvideo", "tooltip": "Filename prefix for the saved model"}), + }, + "optional": { + "save_dtype": (["same", "bf16", "fp16", "fp32"], { + "default": "same", + "tooltip": "Cast weights to this dtype before saving. 'same' keeps the current dtype of each tensor. Recommended to set explicitly if model was loaded in fp8." + }), + "custom_path": ("STRING", { + "default": "", + "tooltip": "Absolute path to save directory. Leave empty to save in ComfyUI/models/diffusion_models/" + }), + }, + } + + RETURN_TYPES = () + FUNCTION = "save_model" + CATEGORY = "WanVideoWrapper" + OUTPUT_NODE = True + DESCRIPTION = "Saves the WanVideo diffusion model (including merged LoRAs) as a safetensors file" + + def save_model(self, model, filename_prefix, save_dtype="same", custom_path=""): + dtype_map = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + } + + # Build output directory + if custom_path and os.path.isabs(custom_path): + output_dir = custom_path + else: + output_dir = os.path.join(folder_paths.models_dir, "diffusion_models") + os.makedirs(output_dir, exist_ok=True) + + # Build filename, avoid overwriting + filename = f"{filename_prefix}.safetensors" + output_path = os.path.join(output_dir, filename) + counter = 1 + while os.path.exists(output_path): + filename = f"{filename_prefix}_{counter}.safetensors" + output_path = os.path.join(output_dir, filename) + counter += 1 + + # Gather metadata about the merge for traceability + metadata = {} + model_name = model.model.pipeline.get("model_name", "unknown") + metadata["source_model"] = str(model_name) + lora_info = model.model.pipeline.get("lora") + if lora_info is not None: + lora_entries = [] + for l in lora_info: + lora_entries.append({ + "name": l.get("name", "unknown"), + "strength": l.get("strength", 1.0), + }) + metadata["merged_loras"] = json.dumps(lora_entries) + metadata["save_dtype"] = save_dtype + + # Extract state dict from the diffusion model (keys are already bare, + # e.g. "blocks.0.self_attn.k.weight" — matching original checkpoint format) + diffusion_model = model.model.diffusion_model + state_dict = diffusion_model.state_dict() + + target_dtype = dtype_map.get(save_dtype) + pbar = ProgressBar(len(state_dict)) + + clean_sd = {} + for k, v in state_dict.items(): + tensor = v.cpu() + if target_dtype is not None: + tensor = tensor.to(target_dtype) + # Clone to break shared memory between aliased tensors + # (e.g. patch_embedding / expanded_patch_embedding / original_patch_embedding) + # safetensors save_file doesn't handle shared tensors, and save_model + # deduplicates keys which breaks compatibility with ComfyUI's load_file + clean_sd[k] = tensor.clone() + pbar.update(1) + + log.info(f"Saving merged WanVideo model to: {output_path}") + log.info(f"Number of tensors: {len(clean_sd)}") + + save_file(clean_sd, output_path, metadata=metadata) + + log.info(f"Model saved successfully: {filename}") + del clean_sd + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return () + + +NODE_CLASS_MAPPINGS = { + "WanVideoSaveMergedModel": WanVideoSaveMergedModel, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "WanVideoSaveMergedModel": "WanVideo Save Merged Model", +}