feat: initial release of ComfyUI-LM-Remote

Remote-aware LoRA Manager nodes that fetch metadata via HTTP from a
remote Docker instance while loading LoRA files from local NFS/SMB
mounts. Includes reverse-proxy middleware for transparent web UI access.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-22 00:46:03 +01:00
commit 980f406573
18 changed files with 2234 additions and 0 deletions

0
nodes/__init__.py Normal file
View File

92
nodes/lora_cycler.py Normal file
View File

@@ -0,0 +1,92 @@
"""Remote LoRA Cycler — uses remote API instead of local ServiceRegistry."""
from __future__ import annotations
import logging
import os
from .remote_utils import get_lora_info_remote
from ..remote_client import RemoteLoraClient
logger = logging.getLogger(__name__)
class LoraCyclerRemoteLM:
    """Node that sequentially cycles through LoRAs from a pool (remote).

    Each execution emits one LoRA from the (remotely fetched) pool as a
    single-entry LORA_STACK and returns the current/next indices in the
    ``ui`` payload so the front-end widget can advance the cycle.
    """

    NAME = "Lora Cycler (Remote, LoraManager)"
    CATEGORY = "Lora Manager/randomizer"

    @classmethod
    def INPUT_TYPES(cls):
        # cycler_config carries index/strength state from the JS widget;
        # pool_config optionally narrows which LoRAs are in the pool.
        return {
            "required": {
                "cycler_config": ("CYCLER_CONFIG", {}),
            },
            "optional": {
                "pool_config": ("POOL_CONFIG", {}),
            },
        }

    RETURN_TYPES = ("LORA_STACK",)
    RETURN_NAMES = ("LORA_STACK",)
    FUNCTION = "cycle"
    OUTPUT_NODE = False

    async def cycle(self, cycler_config, pool_config=None):
        """Return a one-element LORA_STACK for the LoRA at the current index.

        Args:
            cycler_config: dict with ``current_index`` (1-based),
                ``model_strength``, ``clip_strength`` and an optional
                ``execution_index`` that overrides ``current_index``.
            pool_config: optional filter config forwarded to the remote API.

        Returns:
            dict with ``result`` (the LORA_STACK tuple) and ``ui`` data for
            the front-end widget.
        """
        current_index = cycler_config.get("current_index", 1)
        model_strength = float(cycler_config.get("model_strength", 1.0))
        clip_strength = float(cycler_config.get("clip_strength", 1.0))
        execution_index = cycler_config.get("execution_index")
        client = RemoteLoraClient.get_instance()
        # Fetch the pool (sorted by filename) from the remote LoRA Manager.
        lora_list = await client.get_cycler_list(
            pool_config=pool_config, sort_by="filename"
        )
        total_count = len(lora_list)
        if total_count == 0:
            logger.warning("[LoraCyclerRemoteLM] No LoRAs available in pool")
            # Empty stack plus a UI payload resetting the widget to index 1.
            return {
                "result": ([],),
                "ui": {
                    "current_index": [1],
                    "next_index": [1],
                    "total_count": [0],
                    "current_lora_name": [""],
                    "current_lora_filename": [""],
                    "error": ["No LoRAs available in pool"],
                },
            }
        # execution_index (when present) wins over the widget's current_index.
        actual_index = execution_index if execution_index is not None else current_index
        # Clamp into [1, total_count]; indices are 1-based.
        clamped_index = max(1, min(actual_index, total_count))
        current_lora = lora_list[clamped_index - 1]
        lora_path, _ = get_lora_info_remote(current_lora["file_name"])
        if not lora_path:
            # Missing remote metadata: emit an empty stack rather than crash.
            logger.warning("[LoraCyclerRemoteLM] Could not find path for LoRA: %s", current_lora["file_name"])
            lora_stack = []
        else:
            # Normalize separators for the local OS before handing to loaders.
            lora_path = lora_path.replace("/", os.sep)
            lora_stack = [(lora_path, model_strength, clip_strength)]
        # Wrap around to the start of the pool after the last entry.
        next_index = clamped_index + 1
        if next_index > total_count:
            next_index = 1
        next_lora = lora_list[next_index - 1]
        return {
            "result": (lora_stack,),
            "ui": {
                "current_index": [clamped_index],
                "next_index": [next_index],
                "total_count": [total_count],
                "current_lora_name": [current_lora["file_name"]],
                "current_lora_filename": [current_lora["file_name"]],
                "next_lora_name": [next_lora["file_name"]],
                "next_lora_filename": [next_lora["file_name"]],
            },
        }

205
nodes/lora_loader.py Normal file
View File

@@ -0,0 +1,205 @@
"""Remote LoRA Loader nodes — fetch metadata from the remote LoRA Manager."""
from __future__ import annotations
import logging
import re
from nodes import LoraLoader # type: ignore
from .remote_utils import get_lora_info_remote
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
logger = logging.getLogger(__name__)
class LoraLoaderRemoteLM:
    """Apply LoRAs from the widget / an upstream stack to a model.

    Metadata (relative path, trigger words) is fetched from the remote
    LoRA Manager via ``get_lora_info_remote``; the LoRA weights themselves
    are loaded with ComfyUI's built-in ``LoraLoader`` (or the Nunchaku
    helper for quantized Flux models).
    """

    NAME = "Lora Loader (Remote, LoraManager)"
    CATEGORY = "Lora Manager/loaders"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "text": ("AUTOCOMPLETE_TEXT_LORAS", {
                    "placeholder": "Search LoRAs to add...",
                    "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
                }),
            },
            # Accept arbitrary extra inputs (clip, lora_stack, widget "loras").
            "optional": FlexibleOptionalInputType(any_type),
        }

    RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
    RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
    FUNCTION = "load_loras"

    def load_loras(self, model, text, **kwargs):
        """Load all active LoRAs onto ``model`` (and ``clip`` when given).

        Args:
            model: the ComfyUI MODEL to patch.
            text: raw widget text (parsed client-side; the structured data
                arrives via ``kwargs``).
            **kwargs: may contain ``clip``, ``lora_stack`` and the widget's
                ``loras`` payload.

        Returns:
            (model, clip, trigger_words, loaded_loras) where the two strings
            are a trigger-word summary and ``<lora:...>`` syntax.
        """
        loaded_loras = []
        all_trigger_words = []
        clip = kwargs.get("clip", None)
        lora_stack = kwargs.get("lora_stack", None)
        # Nunchaku-quantized Flux models need a dedicated LoRA path.
        is_nunchaku_model = False
        try:
            model_wrapper = model.model.diffusion_model
            if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
                is_nunchaku_model = True
                logger.info("Detected Nunchaku Flux model")
        except (AttributeError, TypeError):
            pass
        # Process lora_stack
        if lora_stack:
            for lora_path, model_strength, clip_strength in lora_stack:
                if is_nunchaku_model:
                    # Nunchaku path applies only the model strength.
                    model = nunchaku_load_lora(model, lora_path, model_strength)
                else:
                    model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
                lora_name = extract_lora_name(lora_path)
                _, trigger_words = get_lora_info_remote(lora_name)
                all_trigger_words.extend(trigger_words)
                # Record "name: m,c" only when the two strengths differ.
                if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
                    loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
                else:
                    loaded_loras.append(f"{lora_name}: {model_strength}")
        # Process loras from widget
        loras_list = get_loras_list(kwargs)
        for lora in loras_list:
            if not lora.get("active", False):
                continue
            lora_name = lora["name"]
            model_strength = float(lora["strength"])
            clip_strength = float(lora.get("clipStrength", model_strength))
            lora_path, trigger_words = get_lora_info_remote(lora_name)
            if is_nunchaku_model:
                model = nunchaku_load_lora(model, lora_path, model_strength)
            else:
                model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
            if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
                loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
            else:
                loaded_loras.append(f"{lora_name}: {model_strength}")
            all_trigger_words.extend(trigger_words)
        # NOTE(review): ",, " separator copied from the original LoRA
        # Manager — presumably intentional; confirm against upstream.
        trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
        # Re-render the "name: strengths" entries as <lora:...> syntax.
        formatted_loras = []
        for item in loaded_loras:
            parts = item.split(":")
            lora_name = parts[0]
            strength_parts = parts[1].strip().split(",")
            if len(strength_parts) > 1:
                formatted_loras.append(f"<lora:{lora_name}:{strength_parts[0].strip()}:{strength_parts[1].strip()}>")
            else:
                formatted_loras.append(f"<lora:{lora_name}:{strength_parts[0].strip()}>")
        formatted_loras_text = " ".join(formatted_loras)
        return (model, clip, trigger_words_text, formatted_loras_text)
class LoraTextLoaderRemoteLM:
    """Load LoRAs described by ``<lora:name:strength>`` text syntax.

    Same loading pipeline as ``LoraLoaderRemoteLM`` but driven by a plain
    string input instead of the structured widget payload.
    """

    NAME = "LoRA Text Loader (Remote, LoraManager)"
    CATEGORY = "Lora Manager/loaders"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "lora_syntax": ("STRING", {
                    "forceInput": True,
                    "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
                }),
            },
            "optional": {
                "clip": ("CLIP",),
                "lora_stack": ("LORA_STACK",),
            },
        }

    RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
    RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
    FUNCTION = "load_loras_from_text"

    def parse_lora_syntax(self, text):
        """Parse ``<lora:name:strength[:clip_strength]>`` tags from ``text``.

        Returns:
            list of dicts with ``name``, ``model_strength`` and
            ``clip_strength`` (clip defaults to the model strength).
        """
        pattern = r"<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>"
        matches = re.findall(pattern, text, re.IGNORECASE)
        loras = []
        for match in matches:
            loras.append({
                "name": match[0],
                "model_strength": float(match[1]),
                # Optional third group: separate clip strength.
                "clip_strength": float(match[2]) if match[2] else float(match[1]),
            })
        return loras

    def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
        """Apply the upstream stack plus every LoRA named in ``lora_syntax``.

        Returns:
            (model, clip, trigger_words, loaded_loras) — same contract as
            ``LoraLoaderRemoteLM.load_loras``.
        """
        loaded_loras = []
        all_trigger_words = []
        # Nunchaku-quantized Flux models need a dedicated LoRA path.
        is_nunchaku_model = False
        try:
            model_wrapper = model.model.diffusion_model
            if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
                is_nunchaku_model = True
                logger.info("Detected Nunchaku Flux model")
        except (AttributeError, TypeError):
            pass
        if lora_stack:
            for lora_path, model_strength, clip_strength in lora_stack:
                if is_nunchaku_model:
                    model = nunchaku_load_lora(model, lora_path, model_strength)
                else:
                    model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
                lora_name = extract_lora_name(lora_path)
                _, trigger_words = get_lora_info_remote(lora_name)
                all_trigger_words.extend(trigger_words)
                # Record "name: m,c" only when the two strengths differ.
                if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
                    loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
                else:
                    loaded_loras.append(f"{lora_name}: {model_strength}")
        parsed_loras = self.parse_lora_syntax(lora_syntax)
        for lora in parsed_loras:
            lora_name = lora["name"]
            model_strength = lora["model_strength"]
            clip_strength = lora["clip_strength"]
            lora_path, trigger_words = get_lora_info_remote(lora_name)
            if is_nunchaku_model:
                model = nunchaku_load_lora(model, lora_path, model_strength)
            else:
                model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
            if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
                loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
            else:
                loaded_loras.append(f"{lora_name}: {model_strength}")
            all_trigger_words.extend(trigger_words)
        # NOTE(review): ",, " separator copied from the original LoRA
        # Manager — presumably intentional; confirm against upstream.
        trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
        # Re-render the "name: strengths" entries as <lora:...> syntax.
        formatted_loras = []
        for item in loaded_loras:
            parts = item.split(":")
            lora_name = parts[0].strip()
            strength_parts = parts[1].strip().split(",")
            if len(strength_parts) > 1:
                formatted_loras.append(f"<lora:{lora_name}:{strength_parts[0].strip()}:{strength_parts[1].strip()}>")
            else:
                formatted_loras.append(f"<lora:{lora_name}:{strength_parts[0].strip()}>")
        formatted_loras_text = " ".join(formatted_loras)
        return (model, clip, trigger_words_text, formatted_loras_text)

55
nodes/lora_pool.py Normal file
View File

@@ -0,0 +1,55 @@
"""Remote LoRA Pool — pure pass-through, just different NAME."""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
class LoraPoolRemoteLM:
    """LoRA Pool that passes through filter config (remote variant for NAME only)."""

    NAME = "Lora Pool (Remote, LoraManager)"
    CATEGORY = "Lora Manager/randomizer"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "pool_config": ("LORA_POOL_CONFIG", {}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("POOL_CONFIG",)
    RETURN_NAMES = ("POOL_CONFIG",)
    FUNCTION = "process"
    OUTPUT_NODE = False

    def process(self, pool_config, unique_id=None):
        """Validate ``pool_config`` and pass its ``filters`` through unchanged.

        Non-dict input is replaced with an empty default configuration; a
        missing ``version`` key is stamped in place.
        """
        if not isinstance(pool_config, dict):
            logger.warning("Invalid pool_config type, using empty config")
            pool_config = self._default_config()
        # Stamp a schema version onto configs that predate it.
        pool_config.setdefault("version", 1)
        fallback_filters = self._default_config()["filters"]
        filters = pool_config.get("filters", fallback_filters)
        logger.debug("[LoraPoolRemoteLM] Processing filters: %s", filters)
        return (filters,)

    @staticmethod
    def _default_config():
        """Return a fresh, empty pool configuration structure."""
        empty_filters = {
            "baseModels": [],
            "tags": {"include": [], "exclude": []},
            "folders": {"include": [], "exclude": []},
            "favoritesOnly": False,
            "license": {"noCreditRequired": False, "allowSelling": False},
        }
        return {
            "version": 1,
            "filters": empty_filters,
            "preview": {"matchCount": 0, "lastUpdated": 0},
        }

113
nodes/lora_randomizer.py Normal file
View File

@@ -0,0 +1,113 @@
"""Remote LoRA Randomizer — uses remote API instead of local ServiceRegistry."""
from __future__ import annotations
import logging
import os
from .remote_utils import get_lora_info_remote
from .utils import extract_lora_name
from ..remote_client import RemoteLoraClient
logger = logging.getLogger(__name__)
class LoraRandomizerRemoteLM:
    """Node that randomly selects LoRAs from a pool (remote metadata)."""

    NAME = "Lora Randomizer (Remote, LoraManager)"
    CATEGORY = "Lora Manager/randomizer"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "randomizer_config": ("RANDOMIZER_CONFIG", {}),
                "loras": ("LORAS", {}),
            },
            "optional": {
                "pool_config": ("POOL_CONFIG", {}),
            },
        }

    RETURN_TYPES = ("LORA_STACK",)
    RETURN_NAMES = ("LORA_STACK",)
    FUNCTION = "randomize"
    OUTPUT_NODE = False

    def _preprocess_loras_input(self, loras):
        # The JS widget may wrap the list as {"__value__": [...]}.
        if isinstance(loras, dict) and "__value__" in loras:
            return loras["__value__"]
        return loras

    async def randomize(self, randomizer_config, loras, pool_config=None):
        """Pick a random LoRA set via the remote API and emit it as a stack.

        Args:
            randomizer_config: widget state — roll mode, count bounds,
                strength ranges, and the seeds for this execution and the
                next UI preview.
            loras: current widget selection; used as-is in "fixed" mode and
                as a fallback whenever the remote call returns nothing.
            pool_config: optional filter config forwarded to the remote API.

        Returns:
            dict with ``result`` (LORA_STACK) and ``ui`` data (the next
            preview set plus the set actually used for this run).
        """
        loras = self._preprocess_loras_input(loras)
        roll_mode = randomizer_config.get("roll_mode", "always")
        execution_seed = randomizer_config.get("execution_seed", None)
        next_seed = randomizer_config.get("next_seed", None)
        if roll_mode == "fixed":
            # Fixed mode: keep the current selection for both UI and execution.
            ui_loras = loras
            execution_loras = loras
        else:
            client = RemoteLoraClient.get_instance()
            # Build common kwargs for remote API
            api_kwargs = self._build_api_kwargs(randomizer_config, loras, pool_config)
            if execution_seed is not None:
                exec_kwargs = {**api_kwargs, "seed": execution_seed}
                execution_loras = await client.get_random_loras(**exec_kwargs)
                if not execution_loras:
                    # Remote returned nothing — fall back to the widget list.
                    execution_loras = loras
            else:
                execution_loras = loras
            # Roll a fresh preview set for the UI using the next seed.
            ui_kwargs = {**api_kwargs, "seed": next_seed}
            ui_loras = await client.get_random_loras(**ui_kwargs)
            if not ui_loras:
                ui_loras = loras
        execution_stack = self._build_execution_stack_from_input(execution_loras)
        return {
            "result": (execution_stack,),
            "ui": {"loras": ui_loras, "last_used": execution_loras},
        }

    def _build_api_kwargs(self, randomizer_config, input_loras, pool_config):
        """Translate widget config into keyword args for the remote API."""
        # Locked entries must survive every reroll.
        locked_loras = [l for l in input_loras if l.get("locked", False)]
        return {
            "count": int(randomizer_config.get("count_fixed", 5)),
            "count_mode": randomizer_config.get("count_mode", "range"),
            "count_min": int(randomizer_config.get("count_min", 3)),
            "count_max": int(randomizer_config.get("count_max", 7)),
            "model_strength_min": float(randomizer_config.get("model_strength_min", 0.0)),
            "model_strength_max": float(randomizer_config.get("model_strength_max", 1.0)),
            "use_same_clip_strength": randomizer_config.get("use_same_clip_strength", True),
            "clip_strength_min": float(randomizer_config.get("clip_strength_min", 0.0)),
            "clip_strength_max": float(randomizer_config.get("clip_strength_max", 1.0)),
            "use_recommended_strength": randomizer_config.get("use_recommended_strength", False),
            "recommended_strength_scale_min": float(randomizer_config.get("recommended_strength_scale_min", 0.5)),
            "recommended_strength_scale_max": float(randomizer_config.get("recommended_strength_scale_max", 1.0)),
            "locked_loras": locked_loras,
            "pool_config": pool_config,
        }

    def _build_execution_stack_from_input(self, loras):
        """Resolve each active LoRA's path remotely and build a LORA_STACK."""
        lora_stack = []
        for lora in loras:
            if not lora.get("active", False):
                continue
            lora_path, _ = get_lora_info_remote(lora["name"])
            if not lora_path:
                # Skip entries the remote instance cannot resolve.
                logger.warning("[LoraRandomizerRemoteLM] Could not find path for LoRA: %s", lora["name"])
                continue
            # Normalize separators for the local OS.
            lora_path = lora_path.replace("/", os.sep)
            model_strength = float(lora.get("strength", 1.0))
            clip_strength = float(lora.get("clipStrength", model_strength))
            lora_stack.append((lora_path, model_strength, clip_strength))
        return lora_stack

71
nodes/lora_stacker.py Normal file
View File

@@ -0,0 +1,71 @@
"""Remote LoRA Stacker — fetch metadata from the remote LoRA Manager."""
from __future__ import annotations
import logging
import os
from .remote_utils import get_lora_info_remote
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
logger = logging.getLogger(__name__)
class LoraStackerRemoteLM:
    """Stack LoRAs from the widget / an upstream stack, resolving paths remotely.

    Mirrors the original LoRA Manager stacker but fetches each LoRA's
    relative path and trigger words from the remote HTTP API via
    ``get_lora_info_remote``.
    """

    NAME = "Lora Stacker (Remote, LoraManager)"
    CATEGORY = "Lora Manager/stackers"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "text": ("AUTOCOMPLETE_TEXT_LORAS", {
                    "placeholder": "Search LoRAs to add...",
                    "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
                }),
            },
            # Accept arbitrary extra inputs (lora_stack, widget "loras").
            "optional": FlexibleOptionalInputType(any_type),
        }

    RETURN_TYPES = ("LORA_STACK", "STRING", "STRING")
    RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
    FUNCTION = "stack_loras"

    def stack_loras(self, text, **kwargs):
        """Build a LORA_STACK from an upstream stack plus the widget's loras.

        Args:
            text: raw widget text (parsed client-side; structured data
                arrives via ``kwargs``).
            **kwargs: may contain ``lora_stack`` and the widget's ``loras``.

        Returns:
            (stack, trigger_words_text, active_loras_text) — trigger words
            joined with ",, " and the newly added entries rendered in
            ``<lora:name:strength>`` syntax.
        """
        stack = []
        active_loras = []
        all_trigger_words = []
        lora_stack = kwargs.get("lora_stack", None)
        if lora_stack:
            # Pass upstream entries through untouched; only collect triggers.
            stack.extend(lora_stack)
            for lora_path, _, _ in lora_stack:
                lora_name = extract_lora_name(lora_path)
                _, trigger_words = get_lora_info_remote(lora_name)
                all_trigger_words.extend(trigger_words)
        loras_list = get_loras_list(kwargs)
        for lora in loras_list:
            if not lora.get("active", False):
                continue
            lora_name = lora["name"]
            model_strength = float(lora["strength"])
            clip_strength = float(lora.get("clipStrength", model_strength))
            lora_path, trigger_words = get_lora_info_remote(lora_name)
            if not lora_path:
                # Fix: a failed remote lookup previously crashed on
                # ``None.replace``; skip the entry with a warning instead,
                # matching the cycler/randomizer nodes.
                logger.warning("[LoraStackerRemoteLM] Could not find path for LoRA: %s", lora_name)
                continue
            # Normalize separators for the local OS.
            stack.append((lora_path.replace("/", os.sep), model_strength, clip_strength))
            active_loras.append((lora_name, model_strength, clip_strength))
            all_trigger_words.extend(trigger_words)
        trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
        formatted_loras = []
        for name, model_strength, clip_strength in active_loras:
            # Emit the clip strength only when it differs from the model strength.
            if abs(model_strength - clip_strength) > 0.001:
                formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
            else:
                formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
        active_loras_text = " ".join(formatted_loras)
        return (stack, trigger_words_text, active_loras_text)

44
nodes/remote_utils.py Normal file
View File

@@ -0,0 +1,44 @@
"""Remote replacement for ``py/utils/utils.py:get_lora_info()``.
Same signature: ``get_lora_info_remote(lora_name) -> (relative_path, trigger_words)``
but fetches data from the remote LoRA Manager HTTP API instead of the local
ServiceRegistry / SQLite cache.
"""
from __future__ import annotations
import asyncio
import concurrent.futures
import logging
from ..remote_client import RemoteLoraClient
logger = logging.getLogger(__name__)
def get_lora_info_remote(lora_name: str) -> tuple[str, list[str]]:
    """Synchronous wrapper that calls the remote API for LoRA metadata.

    Drop-in replacement for the original ``get_lora_info()``: bridges from
    sync node code into the async remote client, whether or not an event
    loop is already running in the current thread.

    Args:
        lora_name: bare LoRA name to look up on the remote instance.

    Returns:
        (relative_path, trigger_words) as provided by the remote API.
    """

    async def _fetch():
        client = RemoteLoraClient.get_instance()
        return await client.get_lora_info(lora_name)

    # Fix: only the get_running_loop() probe may raise the RuntimeError we
    # want to branch on. The previous over-broad try also caught RuntimeError
    # raised by the remote call inside the worker thread and silently retried
    # it via asyncio.run(), which itself fails when a loop is running and
    # masked the real error.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop — safe to use asyncio.run()
        return asyncio.run(_fetch())

    # Already inside an event loop — run in a separate thread.
    def _run_in_thread():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(_fetch())
        finally:
            loop.close()

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(_run_in_thread)
        return future.result()

381
nodes/save_image.py Normal file
View File

@@ -0,0 +1,381 @@
"""Remote Save Image — uses remote API for hash lookups."""
from __future__ import annotations
import asyncio
import concurrent.futures
import json
import logging
import os
import re
import folder_paths # type: ignore
import numpy as np
from PIL import Image, PngImagePlugin
from ..remote_client import RemoteLoraClient
logger = logging.getLogger(__name__)
try:
import piexif
except ImportError:
piexif = None
class SaveImageRemoteLM:
    """Save images with A1111-style "parameters" metadata, using the remote
    LoRA Manager API for model/LoRA hash lookups."""

    NAME = "Save Image (Remote, LoraManager)"
    CATEGORY = "Lora Manager/utils"
    DESCRIPTION = "Save images with embedded generation metadata (remote hash lookup)"

    def __init__(self):
        # Standard ComfyUI save-node state.
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""
        self.compress_level = 4
        self.counter = 0

    # Matches %token% (optionally %token:arg%) segments in filename prefixes.
    pattern_format = re.compile(r"(%[^%]+%)")

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "filename_prefix": ("STRING", {
                    "default": "ComfyUI",
                    "tooltip": "Base filename. Supports %seed%, %width%, %height%, %model%, etc.",
                }),
                "file_format": (["png", "jpeg", "webp"], {
                    "tooltip": "Image format to save as.",
                }),
            },
            "optional": {
                "lossless_webp": ("BOOLEAN", {"default": False}),
                "quality": ("INT", {"default": 100, "min": 1, "max": 100}),
                "embed_workflow": ("BOOLEAN", {"default": False}),
                "add_counter_to_filename": ("BOOLEAN", {"default": True}),
            },
            "hidden": {
                "id": "UNIQUE_ID",
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO",
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "process_image"
    OUTPUT_NODE = True

    # ------------------------------------------------------------------
    # Remote hash lookups
    # ------------------------------------------------------------------
    def _run_async(self, coro):
        """Run an async coroutine from sync context."""
        try:
            asyncio.get_running_loop()
            # A loop is already running in this thread — drive the coroutine
            # on a fresh loop in a worker thread instead.
            def _in_thread():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    return loop.run_until_complete(coro)
                finally:
                    loop.close()
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return pool.submit(_in_thread).result()
        except RuntimeError:
            # No running loop in this thread.
            return asyncio.run(coro)

    def get_lora_hash(self, lora_name):
        """Return the remote-resolved hash for ``lora_name`` (or a falsy value)."""
        client = RemoteLoraClient.get_instance()
        return self._run_async(client.get_lora_hash(lora_name))

    def get_checkpoint_hash(self, checkpoint_path):
        """Return the remote-resolved hash for a checkpoint path, or None."""
        if not checkpoint_path:
            return None
        checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
        client = RemoteLoraClient.get_instance()
        return self._run_async(client.get_checkpoint_hash(checkpoint_name))

    # ------------------------------------------------------------------
    # Metadata formatting (identical to original)
    # ------------------------------------------------------------------
    def format_metadata(self, metadata_dict):
        """Render collected generation metadata as an A1111-style
        "parameters" string (prompt, negative prompt, comma-joined params)."""
        if not metadata_dict:
            return ""
        def add_param_if_not_none(param_list, label, value):
            # Helper: append "Label: value" only for non-None values.
            if value is not None:
                param_list.append(f"{label}: {value}")
        prompt = metadata_dict.get("prompt", "")
        negative_prompt = metadata_dict.get("negative_prompt", "")
        loras_text = metadata_dict.get("loras", "")
        lora_hashes = {}
        if loras_text:
            # Append the <lora:...> syntax to the prompt and resolve each
            # referenced LoRA's hash remotely.
            prompt_with_loras = f"{prompt}\n{loras_text}"
            lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", loras_text)
            for lora_name, _ in lora_matches:
                hash_value = self.get_lora_hash(lora_name)
                if hash_value:
                    lora_hashes[lora_name] = hash_value
        else:
            prompt_with_loras = prompt
        metadata_parts = [prompt_with_loras]
        if negative_prompt:
            metadata_parts.append(f"Negative prompt: {negative_prompt}")
        params = []
        if "steps" in metadata_dict:
            add_param_if_not_none(params, "Steps", metadata_dict.get("steps"))
        sampler_name = None
        scheduler_name = None
        if "sampler" in metadata_dict:
            # Map ComfyUI sampler ids to the A1111 display names.
            sampler_mapping = {
                "euler": "Euler", "euler_ancestral": "Euler a",
                "dpm_2": "DPM2", "dpm_2_ancestral": "DPM2 a",
                "heun": "Heun", "dpm_fast": "DPM fast",
                "dpm_adaptive": "DPM adaptive", "lms": "LMS",
                "dpmpp_2s_ancestral": "DPM++ 2S a", "dpmpp_sde": "DPM++ SDE",
                "dpmpp_sde_gpu": "DPM++ SDE", "dpmpp_2m": "DPM++ 2M",
                "dpmpp_2m_sde": "DPM++ 2M SDE", "dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
                "ddim": "DDIM",
            }
            sampler_name = sampler_mapping.get(metadata_dict["sampler"], metadata_dict["sampler"])
        if "scheduler" in metadata_dict:
            scheduler_mapping = {
                "normal": "Simple", "karras": "Karras",
                "exponential": "Exponential", "sgm_uniform": "SGM Uniform",
                "sgm_quadratic": "SGM Quadratic",
            }
            scheduler_name = scheduler_mapping.get(metadata_dict["scheduler"], metadata_dict["scheduler"])
        if sampler_name:
            if scheduler_name:
                params.append(f"Sampler: {sampler_name} {scheduler_name}")
            else:
                params.append(f"Sampler: {sampler_name}")
        # CFG scale may arrive under several keys depending on the workflow.
        if "guidance" in metadata_dict:
            add_param_if_not_none(params, "CFG scale", metadata_dict.get("guidance"))
        elif "cfg_scale" in metadata_dict:
            add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg_scale"))
        elif "cfg" in metadata_dict:
            add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg"))
        if "seed" in metadata_dict:
            add_param_if_not_none(params, "Seed", metadata_dict.get("seed"))
        if "size" in metadata_dict:
            add_param_if_not_none(params, "Size", metadata_dict.get("size"))
        if "checkpoint" in metadata_dict:
            checkpoint = metadata_dict.get("checkpoint")
            if checkpoint is not None:
                model_hash = self.get_checkpoint_hash(checkpoint)
                checkpoint_name = os.path.splitext(os.path.basename(checkpoint))[0]
                if model_hash:
                    # Truncate to 10 chars, matching A1111's short-hash style.
                    params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
                else:
                    params.append(f"Model: {checkpoint_name}")
        if lora_hashes:
            lora_hash_parts = [f"{n}: {h[:10]}" for n, h in lora_hashes.items()]
            params.append(f'Lora hashes: "{", ".join(lora_hash_parts)}"')
        metadata_parts.append(", ".join(params))
        return "\n".join(metadata_parts)

    def format_filename(self, filename, metadata_dict):
        """Expand %token% placeholders (seed, width, height, prompt excerpts,
        model name, date) in a filename prefix using ``metadata_dict``."""
        if not metadata_dict:
            return filename
        result = re.findall(self.pattern_format, filename)
        for segment in result:
            # "%key:arg%" -> ["key", "arg"]; arg is an optional length/format.
            parts = segment.replace("%", "").split(":")
            key = parts[0]
            if key == "seed" and "seed" in metadata_dict:
                filename = filename.replace(segment, str(metadata_dict.get("seed", "")))
            elif key == "width" and "size" in metadata_dict:
                # size is "WxH" (string) or a (W, H) pair.
                size = metadata_dict.get("size", "x")
                w = size.split("x")[0] if isinstance(size, str) else size[0]
                filename = filename.replace(segment, str(w))
            elif key == "height" and "size" in metadata_dict:
                size = metadata_dict.get("size", "x")
                h = size.split("x")[1] if isinstance(size, str) else size[1]
                filename = filename.replace(segment, str(h))
            elif key == "pprompt" and "prompt" in metadata_dict:
                # Positive prompt, optionally truncated to parts[1] chars.
                p = metadata_dict.get("prompt", "").replace("\n", " ")
                if len(parts) >= 2:
                    p = p[: int(parts[1])]
                filename = filename.replace(segment, p.strip())
            elif key == "nprompt" and "negative_prompt" in metadata_dict:
                p = metadata_dict.get("negative_prompt", "").replace("\n", " ")
                if len(parts) >= 2:
                    p = p[: int(parts[1])]
                filename = filename.replace(segment, p.strip())
            elif key == "model":
                model_value = metadata_dict.get("checkpoint")
                if isinstance(model_value, (bytes, os.PathLike)):
                    model_value = str(model_value)
                if not isinstance(model_value, str) or not model_value:
                    model = "model_unavailable"
                else:
                    model = os.path.splitext(os.path.basename(model_value))[0]
                if len(parts) >= 2:
                    model = model[: int(parts[1])]
                filename = filename.replace(segment, model)
            elif key == "date":
                from datetime import datetime
                now = datetime.now()
                # Substitute yyyy/MM/dd/hh/mm/ss tokens in the date format.
                date_table = {
                    "yyyy": f"{now.year:04d}", "yy": f"{now.year % 100:02d}",
                    "MM": f"{now.month:02d}", "dd": f"{now.day:02d}",
                    "hh": f"{now.hour:02d}", "mm": f"{now.minute:02d}",
                    "ss": f"{now.second:02d}",
                }
                if len(parts) >= 2:
                    date_format = parts[1]
                else:
                    date_format = "yyyyMMddhhmmss"
                for k, v in date_table.items():
                    date_format = date_format.replace(k, v)
                filename = filename.replace(segment, date_format)
        return filename

    # ------------------------------------------------------------------
    # Image saving
    # ------------------------------------------------------------------
    def save_images(self, images, filename_prefix, file_format, id, prompt=None,
                    extra_pnginfo=None, lossless_webp=True, quality=100,
                    embed_workflow=False, add_counter_to_filename=True):
        """Write each image to disk with embedded generation metadata.

        Returns:
            list of {"filename", "subfolder", "type"} dicts for the UI.

        NOTE(review): ``lossless_webp`` defaults to True here but the
        INPUT_TYPES default is False — confirm which is intended.
        """
        results = []
        # Try to get metadata from the original LoRA Manager's collector.
        # The package directory name varies across installs, so we search
        # sys.modules for any loaded module whose path ends with the
        # expected submodule.
        metadata_dict = {}
        try:
            get_metadata = None
            MetadataProcessor = None
            import sys
            for mod_name, mod in sys.modules.items():
                if mod is None:
                    continue
                if mod_name.endswith(".py.metadata_collector") and hasattr(mod, "get_metadata"):
                    get_metadata = mod.get_metadata
                if mod_name.endswith(".py.metadata_collector.metadata_processor") and hasattr(mod, "MetadataProcessor"):
                    MetadataProcessor = mod.MetadataProcessor
                if get_metadata and MetadataProcessor:
                    break
            if get_metadata and MetadataProcessor:
                raw_metadata = get_metadata()
                metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
            else:
                logger.debug("[LM-Remote] metadata_collector not found in loaded modules")
        except Exception:
            # Best-effort: missing collector just means no generation metadata.
            logger.debug("[LM-Remote] metadata_collector not available, saving without generation metadata")
        metadata = self.format_metadata(metadata_dict)
        filename_prefix = self.format_filename(filename_prefix, metadata_dict)
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
        )
        if not os.path.exists(full_output_folder):
            os.makedirs(full_output_folder, exist_ok=True)
        for i, image in enumerate(images):
            # Convert the float tensor (0..1) to an 8-bit PIL image.
            img = 255.0 * image.cpu().numpy()
            img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
            base_filename = filename
            if add_counter_to_filename:
                current_counter = counter + i
                base_filename += f"_{current_counter:05}_"
            if file_format == "png":
                file = base_filename + ".png"
                save_kwargs = {"compress_level": self.compress_level}
                pnginfo = PngImagePlugin.PngInfo()
            elif file_format == "jpeg":
                file = base_filename + ".jpg"
                save_kwargs = {"quality": quality, "optimize": True}
            elif file_format == "webp":
                file = base_filename + ".webp"
                save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
            file_path = os.path.join(full_output_folder, file)
            try:
                if file_format == "png":
                    # PNG: metadata goes into tEXt chunks ("parameters"/"workflow").
                    if metadata:
                        pnginfo.add_text("parameters", metadata)
                    if embed_workflow and extra_pnginfo is not None:
                        pnginfo.add_text("workflow", json.dumps(extra_pnginfo["workflow"]))
                    save_kwargs["pnginfo"] = pnginfo
                    img.save(file_path, format="PNG", **save_kwargs)
                elif file_format == "jpeg" and piexif:
                    # JPEG: metadata goes into the EXIF UserComment (UTF-16BE).
                    if metadata:
                        try:
                            exif_dict = {"Exif": {piexif.ExifIFD.UserComment: b"UNICODE\0" + metadata.encode("utf-16be")}}
                            save_kwargs["exif"] = piexif.dump(exif_dict)
                        except Exception as e:
                            logger.error("Error adding EXIF data: %s", e)
                    img.save(file_path, format="JPEG", **save_kwargs)
                elif file_format == "webp" and piexif:
                    try:
                        exif_dict = {}
                        if metadata:
                            exif_dict["Exif"] = {piexif.ExifIFD.UserComment: b"UNICODE\0" + metadata.encode("utf-16be")}
                        if embed_workflow and extra_pnginfo is not None:
                            exif_dict["0th"] = {piexif.ImageIFD.ImageDescription: "Workflow:" + json.dumps(extra_pnginfo["workflow"])}
                        save_kwargs["exif"] = piexif.dump(exif_dict)
                    except Exception as e:
                        logger.error("Error adding EXIF data: %s", e)
                    img.save(file_path, format="WEBP", **save_kwargs)
                else:
                    # piexif missing for jpeg/webp: save without metadata.
                    img.save(file_path)
                results.append({"filename": file, "subfolder": subfolder, "type": self.type})
            except Exception as e:
                logger.error("Error saving image: %s", e)
        return results

    def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png",
                      prompt=None, extra_pnginfo=None, lossless_webp=True, quality=100,
                      embed_workflow=False, add_counter_to_filename=True):
        """Node entry point: normalize ``images`` to an iterable, save them,
        and pass the images through unchanged."""
        os.makedirs(self.output_dir, exist_ok=True)
        if isinstance(images, (list, np.ndarray)):
            pass
        else:
            # Presumably a torch tensor: a single HWC image gets wrapped,
            # a batch is split into per-image tensors.
            if len(images.shape) == 3:
                images = [images]
            else:
                images = [img for img in images]
        self.save_images(
            images, filename_prefix, file_format, id, prompt, extra_pnginfo,
            lossless_webp, quality, embed_workflow, add_counter_to_filename,
        )
        return (images,)

141
nodes/utils.py Normal file
View File

@@ -0,0 +1,141 @@
"""Minimal utility classes/functions copied from the original LoRA Manager.
Only the pieces needed by the remote node classes are included here so that
ComfyUI-LM-Remote can function independently of the original package's Python
internals (while still requiring its JS widget files).
"""
from __future__ import annotations
import copy
import logging
import os
import sys
import folder_paths # type: ignore
logger = logging.getLogger(__name__)
class AnyType(str):
    """A string subclass that never reports itself as unequal to anything.

    ComfyUI compares declared socket types with ``!=``; by always answering
    "not different", an ``AnyType`` instance acts as a wildcard type.
    Credit to pythongosssss.
    """

    def __ne__(self, other: object) -> bool:
        # Claim equality with every value so type checks always pass.
        return False
class FlexibleOptionalInputType(dict):
    """A dict that claims to contain every key, mapping each to one type.

    Lets a ComfyUI node accept arbitrarily named optional inputs.
    Credit to Regis Gaughan, III (rgthree).
    """

    def __init__(self, type):
        # Parameter intentionally named ``type`` (shadows the builtin) to
        # keep the original keyword-call interface.
        self.type = type

    def __getitem__(self, key):
        # Every lookup resolves to the same single-element type tuple.
        return (self.type,)

    def __contains__(self, key):
        # Membership tests always succeed.
        return True

any_type = AnyType("*")
def extract_lora_name(lora_path: str) -> str:
    """Return the bare LoRA name from a path-like reference.

    ``'IL\\\\aorunIllstrious.safetensors'`` -> ``'aorunIllstrious'``

    Handles both ``/`` and ``\\`` separators regardless of host OS: stack
    entries may carry Windows-style relative paths while the node runs on a
    POSIX mount, where ``os.path.basename`` ignores backslashes (the
    docstring example would otherwise return ``'IL\\aorunIllstrious'``).
    """
    # Normalize Windows separators so basename works on any platform.
    normalized = lora_path.replace("\\", "/")
    basename = os.path.basename(normalized)
    return os.path.splitext(basename)[0]
def get_loras_list(kwargs: dict) -> list:
    """Extract loras list from either old or new kwargs format.

    Supports both the legacy plain-list format and the newer
    ``{"__value__": [...]}`` wrapper emitted by the JS widget; anything
    else is logged and treated as empty.
    """
    if "loras" not in kwargs:
        return []
    payload = kwargs["loras"]
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict) and "__value__" in payload:
        return payload["__value__"]
    logger.warning("Unexpected loras format: %s", type(payload))
    return []
# ---------------------------------------------------------------------------
# Nunchaku LoRA helpers (copied verbatim from original)
# ---------------------------------------------------------------------------
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
    """Load tensors from a .safetensors file into a dict.

    When ``filter_prefix`` is non-empty, only keys starting with it are
    kept, and the prefix is stripped from the returned keys.
    """
    import safetensors.torch

    result = {}
    with safetensors.torch.safe_open(path, framework="pt", device=device) as handle:
        for key in handle.keys():
            if filter_prefix and not key.startswith(filter_prefix):
                continue
            result[key.removeprefix(filter_prefix)] = handle.get_tensor(key)
    return result
def to_diffusers(input_lora):
    """Convert a LoRA (file path or state dict) into diffusers/PEFT format.

    Non-floating tensors are cast to bfloat16 before conversion.
    """
    import torch
    from diffusers.loaders import FluxLoraLoaderMixin
    from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft

    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
    else:
        # Shallow-copy so the caller's dict is never mutated.
        tensors = {key: value for key, value in input_lora.items()}
    float_dtypes = [torch.float64, torch.float32, torch.bfloat16, torch.float16]
    for key, value in tensors.items():
        if value.dtype not in float_dtypes:
            tensors[key] = value.to(torch.bfloat16)
    converted = FluxLoraLoaderMixin.lora_state_dict(tensors)
    converted = convert_unet_state_dict_to_peft(converted)
    return converted
def nunchaku_load_lora(model, lora_name, lora_strength):
    """Attach a LoRA to a Nunchaku-wrapped model and return a patched copy.

    The original ``model`` is returned unchanged if the LoRA file cannot be
    found; otherwise a copy with the LoRA registered on its wrapper is
    returned, leaving the input model's LoRA stack intact.
    """
    # Accept either an absolute file path or a name resolvable by ComfyUI.
    lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
    if not lora_path or not os.path.isfile(lora_path):
        logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
        return model
    model_wrapper = model.model.diffusion_model
    # ComfyUI-nunchaku >= 1.1.0 exposes copy_with_ctx on the wrapper's module
    # (per the upgrade warning below); prefer it when available.
    module_name = model_wrapper.__class__.__module__
    module = sys.modules.get(module_name)
    copy_with_ctx = getattr(module, "copy_with_ctx", None)
    if copy_with_ctx is not None:
        ret_model_wrapper, ret_model = copy_with_ctx(model_wrapper)
        # Build a fresh list rather than appending, so the source wrapper's
        # loras list is not mutated.
        ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
    else:
        logger.warning(
            "Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. "
            "Falling back to legacy loading logic."
        )
        # Legacy path: temporarily detach the heavyweight transformer so
        # deepcopy only clones the lightweight wrapper/model structure, then
        # re-attach the SAME transformer object to both original and copy.
        transformer = model_wrapper.model
        model_wrapper.model = None
        ret_model = copy.deepcopy(model)
        ret_model_wrapper = ret_model.model.diffusion_model
        model_wrapper.model = transformer
        ret_model_wrapper.model = transformer
        # Appending here mutates the copy's own (deep-copied) loras list.
        ret_model_wrapper.loras.append((lora_path, lora_strength))
    sd = to_diffusers(lora_path)
    # Some LoRAs widen the x_embedder input — presumably variants trained
    # with extra input channels (TODO confirm); grow the model config's
    # in_channels to match if needed.
    if "transformer.x_embedder.lora_A.weight" in sd:
        new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
        assert new_in_channels % 4 == 0
        new_in_channels = new_in_channels // 4
        old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
        if old_in_channels < new_in_channels:
            ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
    return ret_model

202
nodes/wanvideo.py Normal file
View File

@@ -0,0 +1,202 @@
"""Remote WanVideo LoRA nodes — fetch metadata from the remote LoRA Manager."""
from __future__ import annotations

import logging
import os

import folder_paths  # type: ignore

from .remote_utils import get_lora_info_remote
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
logger = logging.getLogger(__name__)
class WanVideoLoraSelectRemoteLM:
    """WanVideo LoRA stacker whose metadata comes from the remote LoRA Manager.

    Builds a WANVIDLORA list from the autocomplete widget's selection and also
    returns the combined trigger words plus a ``<lora:...>`` summary of the
    active entries.
    """

    NAME = "WanVideo Lora Select (Remote, LoraManager)"
    CATEGORY = "Lora Manager/stackers"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "low_mem_load": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Load LORA models with less VRAM usage, slower loading.",
                }),
                "merge_loras": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Merge LoRAs into the model.",
                }),
                "text": ("AUTOCOMPLETE_TEXT_LORAS", {
                    "placeholder": "Search LoRAs to add...",
                    "tooltip": "Format: <lora:lora_name:strength>",
                }),
            },
            # Accepts arbitrary extra inputs (prev_lora, blocks, loras widget).
            "optional": FlexibleOptionalInputType(any_type),
        }

    RETURN_TYPES = ("WANVIDLORA", "STRING", "STRING")
    RETURN_NAMES = ("lora", "trigger_words", "active_loras")
    FUNCTION = "process_loras"

    def process_loras(self, text, low_mem_load=False, merge_loras=True, **kwargs):
        """Resolve each active widget LoRA into a WANVIDLORA dict.

        Args:
            text: Raw autocomplete text (selection itself arrives via kwargs).
            low_mem_load: Load with less VRAM (forced off when not merging).
            merge_loras: Merge LoRAs into the model.
            **kwargs: Optional ``prev_lora``, ``blocks``, and the widget's
                ``loras`` payload.

        Returns:
            tuple: (lora list, trigger-words string, active-loras syntax string)
        """
        loras_list = []
        all_trigger_words = []
        active_loras = []

        prev_lora = kwargs.get("prev_lora", None)
        if prev_lora is not None:
            loras_list.extend(prev_lora)

        # Low-memory loading only applies when LoRAs are merged into the model.
        if not merge_loras:
            low_mem_load = False

        blocks = kwargs.get("blocks", {})
        selected_blocks = blocks.get("selected_blocks", {})
        layer_filter = blocks.get("layer_filter", "")

        for lora in get_loras_list(kwargs):
            if not lora.get("active", False):
                continue
            lora_name = lora["name"]
            model_strength = float(lora["strength"])
            clip_strength = float(lora.get("clipStrength", model_strength))

            lora_path, trigger_words = get_lora_info_remote(lora_name)
            loras_list.append({
                "path": folder_paths.get_full_path("loras", lora_path),
                "strength": model_strength,
                # Strip only the file extension; names/dirs may contain dots,
                # so split(".")[0] would truncate them.
                "name": os.path.splitext(lora_path)[0],
                "blocks": selected_blocks,
                "layer_filter": layer_filter,
                "low_mem_load": low_mem_load,
                "merge_loras": merge_loras,
            })
            active_loras.append((lora_name, model_strength, clip_strength))
            all_trigger_words.extend(trigger_words)

        trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

        formatted_loras = []
        for name, ms, cs in active_loras:
            # Only spell out the clip strength when it differs from the model's.
            if abs(ms - cs) > 0.001:
                formatted_loras.append(f"<lora:{name}:{ms}:{cs}>")
            else:
                formatted_loras.append(f"<lora:{name}:{ms}>")
        active_loras_text = " ".join(formatted_loras)

        return (loras_list, trigger_words_text, active_loras_text)
class WanVideoLoraTextSelectRemoteLM:
NAME = "WanVideo Lora Select From Text (Remote, LoraManager)"
CATEGORY = "Lora Manager/stackers"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"low_mem_load": ("BOOLEAN", {
"default": False,
"tooltip": "Load LORA models with less VRAM usage, slower loading.",
}),
"merge_lora": ("BOOLEAN", {
"default": True,
"tooltip": "Merge LoRAs into the model.",
}),
"lora_syntax": ("STRING", {
"multiline": True,
"forceInput": True,
"tooltip": "Connect a TEXT output for LoRA syntax: <lora:name:strength>",
}),
},
"optional": {
"prev_lora": ("WANVIDLORA",),
"blocks": ("BLOCKS",),
},
}
RETURN_TYPES = ("WANVIDLORA", "STRING", "STRING")
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
FUNCTION = "process_loras_from_syntax"
def process_loras_from_syntax(self, lora_syntax, low_mem_load=False, merge_lora=True, **kwargs):
blocks = kwargs.get("blocks", {})
selected_blocks = blocks.get("selected_blocks", {})
layer_filter = blocks.get("layer_filter", "")
loras_list = []
all_trigger_words = []
active_loras = []
prev_lora = kwargs.get("prev_lora", None)
if prev_lora is not None:
loras_list.extend(prev_lora)
if not merge_lora:
low_mem_load = False
parts = lora_syntax.split("<lora:")
for part in parts[1:]:
end_index = part.find(">")
if end_index == -1:
continue
content = part[:end_index]
lora_parts = content.split(":")
lora_name_raw = ""
model_strength = 1.0
clip_strength = 1.0
if len(lora_parts) == 2:
lora_name_raw = lora_parts[0].strip()
try:
model_strength = float(lora_parts[1])
clip_strength = model_strength
except (ValueError, IndexError):
logger.warning("Invalid strength for LoRA '%s'. Skipping.", lora_name_raw)
continue
elif len(lora_parts) >= 3:
lora_name_raw = lora_parts[0].strip()
try:
model_strength = float(lora_parts[1])
clip_strength = float(lora_parts[2])
except (ValueError, IndexError):
logger.warning("Invalid strengths for LoRA '%s'. Skipping.", lora_name_raw)
continue
else:
continue
lora_path, trigger_words = get_lora_info_remote(lora_name_raw)
lora_item = {
"path": folder_paths.get_full_path("loras", lora_path),
"strength": model_strength,
"name": lora_path.split(".")[0],
"blocks": selected_blocks,
"layer_filter": layer_filter,
"low_mem_load": low_mem_load,
"merge_loras": merge_lora,
}
loras_list.append(lora_item)
active_loras.append((lora_name_raw, model_strength, clip_strength))
all_trigger_words.extend(trigger_words)
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
formatted_loras = []
for name, ms, cs in active_loras:
if abs(ms - cs) > 0.001:
formatted_loras.append(f"<lora:{name}:{str(ms).strip()}:{str(cs).strip()}>")
else:
formatted_loras.append(f"<lora:{name}:{str(ms).strip()}>")
active_loras_text = " ".join(formatted_loras)
return (loras_list, trigger_words_text, active_loras_text)