node
This commit is contained in:
3
__init__.py
Normal file
3
__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
||||
112
nodes.py
Normal file
112
nodes.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import torch
|
||||
import folder_paths
|
||||
from safetensors.torch import save_file
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
log = logging.getLogger("ComfyUI-WanVideoSaveMerged")
|
||||
|
||||
|
||||
class WanVideoSaveMergedModel:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("WANVIDEOMODEL", {"tooltip": "WanVideo model with merged LoRA from the WanVideo Model Loader"}),
|
||||
"filename_prefix": ("STRING", {"default": "merged_wanvideo", "tooltip": "Filename prefix for the saved model"}),
|
||||
},
|
||||
"optional": {
|
||||
"save_dtype": (["same", "bf16", "fp16", "fp32"], {
|
||||
"default": "same",
|
||||
"tooltip": "Cast weights to this dtype before saving. 'same' keeps the current dtype of each tensor. Recommended to set explicitly if model was loaded in fp8."
|
||||
}),
|
||||
"custom_path": ("STRING", {
|
||||
"default": "",
|
||||
"tooltip": "Absolute path to save directory. Leave empty to save in ComfyUI/models/diffusion_models/"
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_model"
|
||||
CATEGORY = "WanVideoWrapper"
|
||||
OUTPUT_NODE = True
|
||||
DESCRIPTION = "Saves the WanVideo diffusion model (including merged LoRAs) as a safetensors file"
|
||||
|
||||
def save_model(self, model, filename_prefix, save_dtype="same", custom_path=""):
|
||||
dtype_map = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}
|
||||
|
||||
# Build output directory
|
||||
if custom_path and os.path.isabs(custom_path):
|
||||
output_dir = custom_path
|
||||
else:
|
||||
output_dir = os.path.join(folder_paths.models_dir, "diffusion_models")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Build filename, avoid overwriting
|
||||
filename = f"{filename_prefix}.safetensors"
|
||||
output_path = os.path.join(output_dir, filename)
|
||||
counter = 1
|
||||
while os.path.exists(output_path):
|
||||
filename = f"{filename_prefix}_{counter}.safetensors"
|
||||
output_path = os.path.join(output_dir, filename)
|
||||
counter += 1
|
||||
|
||||
# Gather metadata about the merge for traceability
|
||||
metadata = {}
|
||||
model_name = model.model.pipeline.get("model_name", "unknown")
|
||||
metadata["source_model"] = str(model_name)
|
||||
lora_info = model.model.pipeline.get("lora")
|
||||
if lora_info is not None:
|
||||
lora_entries = []
|
||||
for l in lora_info:
|
||||
lora_entries.append({
|
||||
"name": l.get("name", "unknown"),
|
||||
"strength": l.get("strength", 1.0),
|
||||
})
|
||||
metadata["merged_loras"] = json.dumps(lora_entries)
|
||||
metadata["save_dtype"] = save_dtype
|
||||
|
||||
# Extract state dict from the diffusion model (keys are already bare,
|
||||
# e.g. "blocks.0.self_attn.k.weight" — matching original checkpoint format)
|
||||
diffusion_model = model.model.diffusion_model
|
||||
state_dict = diffusion_model.state_dict()
|
||||
|
||||
target_dtype = dtype_map.get(save_dtype)
|
||||
pbar = ProgressBar(len(state_dict))
|
||||
|
||||
clean_sd = {}
|
||||
for k, v in state_dict.items():
|
||||
tensor = v.cpu()
|
||||
if target_dtype is not None:
|
||||
tensor = tensor.to(target_dtype)
|
||||
clean_sd[k] = tensor
|
||||
pbar.update(1)
|
||||
|
||||
log.info(f"Saving merged WanVideo model to: {output_path}")
|
||||
log.info(f"Number of tensors: {len(clean_sd)}")
|
||||
|
||||
save_file(clean_sd, output_path, metadata=metadata)
|
||||
|
||||
log.info(f"Model saved successfully: {filename}")
|
||||
del clean_sd
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return ()
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanVideoSaveMergedModel": WanVideoSaveMergedModel,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanVideoSaveMergedModel": "WanVideo Save Merged Model",
|
||||
}
|
||||
Reference in New Issue
Block a user