feat: initial release of ComfyUI-LM-Remote

Remote-aware LoRA Manager nodes that fetch metadata via HTTP from a
remote Docker instance while loading LoRA files from local NFS/SMB
mounts. Includes reverse-proxy middleware for transparent web UI access.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-22 00:46:03 +01:00
commit 980f406573
18 changed files with 2234 additions and 0 deletions

54
__init__.py Normal file
View File

@@ -0,0 +1,54 @@
"""
ComfyUI-LM-Remote — Remote LoRA Manager integration for ComfyUI.
Provides:
1. A reverse-proxy middleware that forwards all LoRA Manager API/UI/WS
requests to a remote Docker instance.
2. Remote-aware node classes that fetch metadata via HTTP instead of the
local ServiceRegistry, while still loading LoRA files from local
NFS/SMB-mounted paths.
Requires the original ComfyUI-Lora-Manager package to be installed alongside
for its widget JS files and custom widget types.
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
# ── Import node classes ────────────────────────────────────────────────
from .nodes.lora_loader import LoraLoaderRemoteLM, LoraTextLoaderRemoteLM
from .nodes.lora_stacker import LoraStackerRemoteLM
from .nodes.lora_randomizer import LoraRandomizerRemoteLM
from .nodes.lora_cycler import LoraCyclerRemoteLM
from .nodes.lora_pool import LoraPoolRemoteLM
from .nodes.save_image import SaveImageRemoteLM
from .nodes.wanvideo import WanVideoLoraSelectRemoteLM, WanVideoLoraTextSelectRemoteLM
# ── NODE_CLASS_MAPPINGS (how ComfyUI discovers nodes) ──────────────────
NODE_CLASS_MAPPINGS = {
LoraLoaderRemoteLM.NAME: LoraLoaderRemoteLM,
LoraTextLoaderRemoteLM.NAME: LoraTextLoaderRemoteLM,
LoraStackerRemoteLM.NAME: LoraStackerRemoteLM,
LoraRandomizerRemoteLM.NAME: LoraRandomizerRemoteLM,
LoraCyclerRemoteLM.NAME: LoraCyclerRemoteLM,
LoraPoolRemoteLM.NAME: LoraPoolRemoteLM,
SaveImageRemoteLM.NAME: SaveImageRemoteLM,
WanVideoLoraSelectRemoteLM.NAME: WanVideoLoraSelectRemoteLM,
WanVideoLoraTextSelectRemoteLM.NAME: WanVideoLoraTextSelectRemoteLM,
}
# ── WEB_DIRECTORY tells ComfyUI where to find our JS extensions ───────
WEB_DIRECTORY = "./web/comfyui"
# ── Register proxy middleware ──────────────────────────────────────────
try:
from server import PromptServer # type: ignore
from .proxy import register_proxy
register_proxy(PromptServer.instance.app)
except Exception as exc:
logger.warning("[LM-Remote] Could not register proxy middleware: %s", exc)
__all__ = ["NODE_CLASS_MAPPINGS", "WEB_DIRECTORY"]

5
config.json Normal file
View File

@@ -0,0 +1,5 @@
{
"remote_url": "http://192.168.1.100:8188",
"timeout": 30,
"path_mappings": {}
}

62
config.py Normal file
View File

@@ -0,0 +1,62 @@
"""Configuration for ComfyUI-LM-Remote."""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
logger = logging.getLogger(__name__)
_PACKAGE_DIR = Path(__file__).resolve().parent
_CONFIG_FILE = _PACKAGE_DIR / "config.json"
class RemoteConfig:
"""Holds remote LoRA Manager connection settings."""
def __init__(self):
self.remote_url: str = ""
self.timeout: int = 30
self.path_mappings: dict[str, str] = {}
self._load()
# ------------------------------------------------------------------
def _load(self):
# Environment variable takes priority
env_url = os.environ.get("LM_REMOTE_URL", "")
env_timeout = os.environ.get("LM_REMOTE_TIMEOUT", "")
# Load config.json defaults
if _CONFIG_FILE.exists():
try:
with open(_CONFIG_FILE, "r", encoding="utf-8") as f:
data = json.load(f)
self.remote_url = data.get("remote_url", "")
self.timeout = int(data.get("timeout", 30))
self.path_mappings = data.get("path_mappings", {})
except Exception as exc:
logger.warning("[LM-Remote] Failed to read config.json: %s", exc)
# Env overrides
if env_url:
self.remote_url = env_url
if env_timeout:
self.timeout = int(env_timeout)
# Strip trailing slash
self.remote_url = self.remote_url.rstrip("/")
@property
def is_configured(self) -> bool:
return bool(self.remote_url)
def map_path(self, remote_path: str) -> str:
"""Apply remote->local path prefix mappings."""
for remote_prefix, local_prefix in self.path_mappings.items():
if remote_path.startswith(remote_prefix):
return local_prefix + remote_path[len(remote_prefix):]
return remote_path
remote_config = RemoteConfig()

0
nodes/__init__.py Normal file
View File

92
nodes/lora_cycler.py Normal file
View File

@@ -0,0 +1,92 @@
"""Remote LoRA Cycler — uses remote API instead of local ServiceRegistry."""
from __future__ import annotations
import logging
import os
from .remote_utils import get_lora_info_remote
from ..remote_client import RemoteLoraClient
logger = logging.getLogger(__name__)
class LoraCyclerRemoteLM:
"""Node that sequentially cycles through LoRAs from a pool (remote)."""
NAME = "Lora Cycler (Remote, LoraManager)"
CATEGORY = "Lora Manager/randomizer"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"cycler_config": ("CYCLER_CONFIG", {}),
},
"optional": {
"pool_config": ("POOL_CONFIG", {}),
},
}
RETURN_TYPES = ("LORA_STACK",)
RETURN_NAMES = ("LORA_STACK",)
FUNCTION = "cycle"
OUTPUT_NODE = False
async def cycle(self, cycler_config, pool_config=None):
current_index = cycler_config.get("current_index", 1)
model_strength = float(cycler_config.get("model_strength", 1.0))
clip_strength = float(cycler_config.get("clip_strength", 1.0))
execution_index = cycler_config.get("execution_index")
client = RemoteLoraClient.get_instance()
lora_list = await client.get_cycler_list(
pool_config=pool_config, sort_by="filename"
)
total_count = len(lora_list)
if total_count == 0:
logger.warning("[LoraCyclerRemoteLM] No LoRAs available in pool")
return {
"result": ([],),
"ui": {
"current_index": [1],
"next_index": [1],
"total_count": [0],
"current_lora_name": [""],
"current_lora_filename": [""],
"error": ["No LoRAs available in pool"],
},
}
actual_index = execution_index if execution_index is not None else current_index
clamped_index = max(1, min(actual_index, total_count))
current_lora = lora_list[clamped_index - 1]
lora_path, _ = get_lora_info_remote(current_lora["file_name"])
if not lora_path:
logger.warning("[LoraCyclerRemoteLM] Could not find path for LoRA: %s", current_lora["file_name"])
lora_stack = []
else:
lora_path = lora_path.replace("/", os.sep)
lora_stack = [(lora_path, model_strength, clip_strength)]
next_index = clamped_index + 1
if next_index > total_count:
next_index = 1
next_lora = lora_list[next_index - 1]
return {
"result": (lora_stack,),
"ui": {
"current_index": [clamped_index],
"next_index": [next_index],
"total_count": [total_count],
"current_lora_name": [current_lora["file_name"]],
"current_lora_filename": [current_lora["file_name"]],
"next_lora_name": [next_lora["file_name"]],
"next_lora_filename": [next_lora["file_name"]],
},
}

205
nodes/lora_loader.py Normal file
View File

@@ -0,0 +1,205 @@
"""Remote LoRA Loader nodes — fetch metadata from the remote LoRA Manager."""
from __future__ import annotations
import logging
import re
from nodes import LoraLoader # type: ignore
from .remote_utils import get_lora_info_remote
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
logger = logging.getLogger(__name__)
class LoraLoaderRemoteLM:
NAME = "Lora Loader (Remote, LoraManager)"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"text": ("AUTOCOMPLETE_TEXT_LORAS", {
"placeholder": "Search LoRAs to add...",
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
}),
},
"optional": FlexibleOptionalInputType(any_type),
}
RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
FUNCTION = "load_loras"
def load_loras(self, model, text, **kwargs):
loaded_loras = []
all_trigger_words = []
clip = kwargs.get("clip", None)
lora_stack = kwargs.get("lora_stack", None)
is_nunchaku_model = False
try:
model_wrapper = model.model.diffusion_model
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
is_nunchaku_model = True
logger.info("Detected Nunchaku Flux model")
except (AttributeError, TypeError):
pass
# Process lora_stack
if lora_stack:
for lora_path, model_strength, clip_strength in lora_stack:
if is_nunchaku_model:
model = nunchaku_load_lora(model, lora_path, model_strength)
else:
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
lora_name = extract_lora_name(lora_path)
_, trigger_words = get_lora_info_remote(lora_name)
all_trigger_words.extend(trigger_words)
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
# Process loras from widget
loras_list = get_loras_list(kwargs)
for lora in loras_list:
if not lora.get("active", False):
continue
lora_name = lora["name"]
model_strength = float(lora["strength"])
clip_strength = float(lora.get("clipStrength", model_strength))
lora_path, trigger_words = get_lora_info_remote(lora_name)
if is_nunchaku_model:
model = nunchaku_load_lora(model, lora_path, model_strength)
else:
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
all_trigger_words.extend(trigger_words)
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
formatted_loras = []
for item in loaded_loras:
parts = item.split(":")
lora_name = parts[0]
strength_parts = parts[1].strip().split(",")
if len(strength_parts) > 1:
formatted_loras.append(f"<lora:{lora_name}:{strength_parts[0].strip()}:{strength_parts[1].strip()}>")
else:
formatted_loras.append(f"<lora:{lora_name}:{strength_parts[0].strip()}>")
formatted_loras_text = " ".join(formatted_loras)
return (model, clip, trigger_words_text, formatted_loras_text)
class LoraTextLoaderRemoteLM:
NAME = "LoRA Text Loader (Remote, LoraManager)"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"lora_syntax": ("STRING", {
"forceInput": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
}),
},
"optional": {
"clip": ("CLIP",),
"lora_stack": ("LORA_STACK",),
},
}
RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
FUNCTION = "load_loras_from_text"
def parse_lora_syntax(self, text):
pattern = r"<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>"
matches = re.findall(pattern, text, re.IGNORECASE)
loras = []
for match in matches:
loras.append({
"name": match[0],
"model_strength": float(match[1]),
"clip_strength": float(match[2]) if match[2] else float(match[1]),
})
return loras
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
loaded_loras = []
all_trigger_words = []
is_nunchaku_model = False
try:
model_wrapper = model.model.diffusion_model
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
is_nunchaku_model = True
logger.info("Detected Nunchaku Flux model")
except (AttributeError, TypeError):
pass
if lora_stack:
for lora_path, model_strength, clip_strength in lora_stack:
if is_nunchaku_model:
model = nunchaku_load_lora(model, lora_path, model_strength)
else:
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
lora_name = extract_lora_name(lora_path)
_, trigger_words = get_lora_info_remote(lora_name)
all_trigger_words.extend(trigger_words)
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
parsed_loras = self.parse_lora_syntax(lora_syntax)
for lora in parsed_loras:
lora_name = lora["name"]
model_strength = lora["model_strength"]
clip_strength = lora["clip_strength"]
lora_path, trigger_words = get_lora_info_remote(lora_name)
if is_nunchaku_model:
model = nunchaku_load_lora(model, lora_path, model_strength)
else:
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
all_trigger_words.extend(trigger_words)
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
formatted_loras = []
for item in loaded_loras:
parts = item.split(":")
lora_name = parts[0].strip()
strength_parts = parts[1].strip().split(",")
if len(strength_parts) > 1:
formatted_loras.append(f"<lora:{lora_name}:{strength_parts[0].strip()}:{strength_parts[1].strip()}>")
else:
formatted_loras.append(f"<lora:{lora_name}:{strength_parts[0].strip()}>")
formatted_loras_text = " ".join(formatted_loras)
return (model, clip, trigger_words_text, formatted_loras_text)

55
nodes/lora_pool.py Normal file
View File

@@ -0,0 +1,55 @@
"""Remote LoRA Pool — pure pass-through, just different NAME."""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
class LoraPoolRemoteLM:
"""LoRA Pool that passes through filter config (remote variant for NAME only)."""
NAME = "Lora Pool (Remote, LoraManager)"
CATEGORY = "Lora Manager/randomizer"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"pool_config": ("LORA_POOL_CONFIG", {}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("POOL_CONFIG",)
RETURN_NAMES = ("POOL_CONFIG",)
FUNCTION = "process"
OUTPUT_NODE = False
def process(self, pool_config, unique_id=None):
if not isinstance(pool_config, dict):
logger.warning("Invalid pool_config type, using empty config")
pool_config = self._default_config()
if "version" not in pool_config:
pool_config["version"] = 1
filters = pool_config.get("filters", self._default_config()["filters"])
logger.debug("[LoraPoolRemoteLM] Processing filters: %s", filters)
return (filters,)
@staticmethod
def _default_config():
return {
"version": 1,
"filters": {
"baseModels": [],
"tags": {"include": [], "exclude": []},
"folders": {"include": [], "exclude": []},
"favoritesOnly": False,
"license": {"noCreditRequired": False, "allowSelling": False},
},
"preview": {"matchCount": 0, "lastUpdated": 0},
}

113
nodes/lora_randomizer.py Normal file
View File

@@ -0,0 +1,113 @@
"""Remote LoRA Randomizer — uses remote API instead of local ServiceRegistry."""
from __future__ import annotations
import logging
import os
from .remote_utils import get_lora_info_remote
from .utils import extract_lora_name
from ..remote_client import RemoteLoraClient
logger = logging.getLogger(__name__)
class LoraRandomizerRemoteLM:
"""Node that randomly selects LoRAs from a pool (remote metadata)."""
NAME = "Lora Randomizer (Remote, LoraManager)"
CATEGORY = "Lora Manager/randomizer"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"randomizer_config": ("RANDOMIZER_CONFIG", {}),
"loras": ("LORAS", {}),
},
"optional": {
"pool_config": ("POOL_CONFIG", {}),
},
}
RETURN_TYPES = ("LORA_STACK",)
RETURN_NAMES = ("LORA_STACK",)
FUNCTION = "randomize"
OUTPUT_NODE = False
def _preprocess_loras_input(self, loras):
if isinstance(loras, dict) and "__value__" in loras:
return loras["__value__"]
return loras
async def randomize(self, randomizer_config, loras, pool_config=None):
loras = self._preprocess_loras_input(loras)
roll_mode = randomizer_config.get("roll_mode", "always")
execution_seed = randomizer_config.get("execution_seed", None)
next_seed = randomizer_config.get("next_seed", None)
if roll_mode == "fixed":
ui_loras = loras
execution_loras = loras
else:
client = RemoteLoraClient.get_instance()
# Build common kwargs for remote API
api_kwargs = self._build_api_kwargs(randomizer_config, loras, pool_config)
if execution_seed is not None:
exec_kwargs = {**api_kwargs, "seed": execution_seed}
execution_loras = await client.get_random_loras(**exec_kwargs)
if not execution_loras:
execution_loras = loras
else:
execution_loras = loras
ui_kwargs = {**api_kwargs, "seed": next_seed}
ui_loras = await client.get_random_loras(**ui_kwargs)
if not ui_loras:
ui_loras = loras
execution_stack = self._build_execution_stack_from_input(execution_loras)
return {
"result": (execution_stack,),
"ui": {"loras": ui_loras, "last_used": execution_loras},
}
def _build_api_kwargs(self, randomizer_config, input_loras, pool_config):
locked_loras = [l for l in input_loras if l.get("locked", False)]
return {
"count": int(randomizer_config.get("count_fixed", 5)),
"count_mode": randomizer_config.get("count_mode", "range"),
"count_min": int(randomizer_config.get("count_min", 3)),
"count_max": int(randomizer_config.get("count_max", 7)),
"model_strength_min": float(randomizer_config.get("model_strength_min", 0.0)),
"model_strength_max": float(randomizer_config.get("model_strength_max", 1.0)),
"use_same_clip_strength": randomizer_config.get("use_same_clip_strength", True),
"clip_strength_min": float(randomizer_config.get("clip_strength_min", 0.0)),
"clip_strength_max": float(randomizer_config.get("clip_strength_max", 1.0)),
"use_recommended_strength": randomizer_config.get("use_recommended_strength", False),
"recommended_strength_scale_min": float(randomizer_config.get("recommended_strength_scale_min", 0.5)),
"recommended_strength_scale_max": float(randomizer_config.get("recommended_strength_scale_max", 1.0)),
"locked_loras": locked_loras,
"pool_config": pool_config,
}
def _build_execution_stack_from_input(self, loras):
lora_stack = []
for lora in loras:
if not lora.get("active", False):
continue
lora_path, _ = get_lora_info_remote(lora["name"])
if not lora_path:
logger.warning("[LoraRandomizerRemoteLM] Could not find path for LoRA: %s", lora["name"])
continue
lora_path = lora_path.replace("/", os.sep)
model_strength = float(lora.get("strength", 1.0))
clip_strength = float(lora.get("clipStrength", model_strength))
lora_stack.append((lora_path, model_strength, clip_strength))
return lora_stack

71
nodes/lora_stacker.py Normal file
View File

@@ -0,0 +1,71 @@
"""Remote LoRA Stacker — fetch metadata from the remote LoRA Manager."""
from __future__ import annotations
import logging
import os
from .remote_utils import get_lora_info_remote
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
logger = logging.getLogger(__name__)
class LoraStackerRemoteLM:
NAME = "Lora Stacker (Remote, LoraManager)"
CATEGORY = "Lora Manager/stackers"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("AUTOCOMPLETE_TEXT_LORAS", {
"placeholder": "Search LoRAs to add...",
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
}),
},
"optional": FlexibleOptionalInputType(any_type),
}
RETURN_TYPES = ("LORA_STACK", "STRING", "STRING")
RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
FUNCTION = "stack_loras"
def stack_loras(self, text, **kwargs):
stack = []
active_loras = []
all_trigger_words = []
lora_stack = kwargs.get("lora_stack", None)
if lora_stack:
stack.extend(lora_stack)
for lora_path, _, _ in lora_stack:
lora_name = extract_lora_name(lora_path)
_, trigger_words = get_lora_info_remote(lora_name)
all_trigger_words.extend(trigger_words)
loras_list = get_loras_list(kwargs)
for lora in loras_list:
if not lora.get("active", False):
continue
lora_name = lora["name"]
model_strength = float(lora["strength"])
clip_strength = float(lora.get("clipStrength", model_strength))
lora_path, trigger_words = get_lora_info_remote(lora_name)
stack.append((lora_path.replace("/", os.sep), model_strength, clip_strength))
active_loras.append((lora_name, model_strength, clip_strength))
all_trigger_words.extend(trigger_words)
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
formatted_loras = []
for name, model_strength, clip_strength in active_loras:
if abs(model_strength - clip_strength) > 0.001:
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
else:
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
active_loras_text = " ".join(formatted_loras)
return (stack, trigger_words_text, active_loras_text)

44
nodes/remote_utils.py Normal file
View File

@@ -0,0 +1,44 @@
"""Remote replacement for ``py/utils/utils.py:get_lora_info()``.
Same signature: ``get_lora_info_remote(lora_name) -> (relative_path, trigger_words)``
but fetches data from the remote LoRA Manager HTTP API instead of the local
ServiceRegistry / SQLite cache.
"""
from __future__ import annotations
import asyncio
import concurrent.futures
import logging
from ..remote_client import RemoteLoraClient
logger = logging.getLogger(__name__)
def get_lora_info_remote(lora_name: str) -> tuple[str, list[str]]:
"""Synchronous wrapper that calls the remote API for LoRA metadata.
Uses the same sync-from-async bridge pattern as the original
``get_lora_info()`` to be a drop-in replacement in node ``FUNCTION`` methods.
"""
async def _fetch():
client = RemoteLoraClient.get_instance()
return await client.get_lora_info(lora_name)
try:
asyncio.get_running_loop()
# Already inside an event loop — run in a separate thread.
def _run_in_thread():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(_fetch())
finally:
loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(_run_in_thread)
return future.result()
except RuntimeError:
# No running loop — safe to use asyncio.run()
return asyncio.run(_fetch())

381
nodes/save_image.py Normal file
View File

@@ -0,0 +1,381 @@
"""Remote Save Image — uses remote API for hash lookups."""
from __future__ import annotations
import asyncio
import concurrent.futures
import json
import logging
import os
import re
import folder_paths # type: ignore
import numpy as np
from PIL import Image, PngImagePlugin
from ..remote_client import RemoteLoraClient
logger = logging.getLogger(__name__)
try:
import piexif
except ImportError:
piexif = None
class SaveImageRemoteLM:
NAME = "Save Image (Remote, LoraManager)"
CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Save images with embedded generation metadata (remote hash lookup)"
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
self.compress_level = 4
self.counter = 0
pattern_format = re.compile(r"(%[^%]+%)")
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"filename_prefix": ("STRING", {
"default": "ComfyUI",
"tooltip": "Base filename. Supports %seed%, %width%, %height%, %model%, etc.",
}),
"file_format": (["png", "jpeg", "webp"], {
"tooltip": "Image format to save as.",
}),
},
"optional": {
"lossless_webp": ("BOOLEAN", {"default": False}),
"quality": ("INT", {"default": 100, "min": 1, "max": 100}),
"embed_workflow": ("BOOLEAN", {"default": False}),
"add_counter_to_filename": ("BOOLEAN", {"default": True}),
},
"hidden": {
"id": "UNIQUE_ID",
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO",
},
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "process_image"
OUTPUT_NODE = True
# ------------------------------------------------------------------
# Remote hash lookups
# ------------------------------------------------------------------
def _run_async(self, coro):
"""Run an async coroutine from sync context."""
try:
asyncio.get_running_loop()
def _in_thread():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coro)
finally:
loop.close()
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(_in_thread).result()
except RuntimeError:
return asyncio.run(coro)
def get_lora_hash(self, lora_name):
client = RemoteLoraClient.get_instance()
return self._run_async(client.get_lora_hash(lora_name))
def get_checkpoint_hash(self, checkpoint_path):
if not checkpoint_path:
return None
checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
client = RemoteLoraClient.get_instance()
return self._run_async(client.get_checkpoint_hash(checkpoint_name))
# ------------------------------------------------------------------
# Metadata formatting (identical to original)
# ------------------------------------------------------------------
def format_metadata(self, metadata_dict):
if not metadata_dict:
return ""
def add_param_if_not_none(param_list, label, value):
if value is not None:
param_list.append(f"{label}: {value}")
prompt = metadata_dict.get("prompt", "")
negative_prompt = metadata_dict.get("negative_prompt", "")
loras_text = metadata_dict.get("loras", "")
lora_hashes = {}
if loras_text:
prompt_with_loras = f"{prompt}\n{loras_text}"
lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", loras_text)
for lora_name, _ in lora_matches:
hash_value = self.get_lora_hash(lora_name)
if hash_value:
lora_hashes[lora_name] = hash_value
else:
prompt_with_loras = prompt
metadata_parts = [prompt_with_loras]
if negative_prompt:
metadata_parts.append(f"Negative prompt: {negative_prompt}")
params = []
if "steps" in metadata_dict:
add_param_if_not_none(params, "Steps", metadata_dict.get("steps"))
sampler_name = None
scheduler_name = None
if "sampler" in metadata_dict:
sampler_mapping = {
"euler": "Euler", "euler_ancestral": "Euler a",
"dpm_2": "DPM2", "dpm_2_ancestral": "DPM2 a",
"heun": "Heun", "dpm_fast": "DPM fast",
"dpm_adaptive": "DPM adaptive", "lms": "LMS",
"dpmpp_2s_ancestral": "DPM++ 2S a", "dpmpp_sde": "DPM++ SDE",
"dpmpp_sde_gpu": "DPM++ SDE", "dpmpp_2m": "DPM++ 2M",
"dpmpp_2m_sde": "DPM++ 2M SDE", "dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
"ddim": "DDIM",
}
sampler_name = sampler_mapping.get(metadata_dict["sampler"], metadata_dict["sampler"])
if "scheduler" in metadata_dict:
scheduler_mapping = {
"normal": "Simple", "karras": "Karras",
"exponential": "Exponential", "sgm_uniform": "SGM Uniform",
"sgm_quadratic": "SGM Quadratic",
}
scheduler_name = scheduler_mapping.get(metadata_dict["scheduler"], metadata_dict["scheduler"])
if sampler_name:
if scheduler_name:
params.append(f"Sampler: {sampler_name} {scheduler_name}")
else:
params.append(f"Sampler: {sampler_name}")
if "guidance" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get("guidance"))
elif "cfg_scale" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg_scale"))
elif "cfg" in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg"))
if "seed" in metadata_dict:
add_param_if_not_none(params, "Seed", metadata_dict.get("seed"))
if "size" in metadata_dict:
add_param_if_not_none(params, "Size", metadata_dict.get("size"))
if "checkpoint" in metadata_dict:
checkpoint = metadata_dict.get("checkpoint")
if checkpoint is not None:
model_hash = self.get_checkpoint_hash(checkpoint)
checkpoint_name = os.path.splitext(os.path.basename(checkpoint))[0]
if model_hash:
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
else:
params.append(f"Model: {checkpoint_name}")
if lora_hashes:
lora_hash_parts = [f"{n}: {h[:10]}" for n, h in lora_hashes.items()]
params.append(f'Lora hashes: "{", ".join(lora_hash_parts)}"')
metadata_parts.append(", ".join(params))
return "\n".join(metadata_parts)
def format_filename(self, filename, metadata_dict):
if not metadata_dict:
return filename
result = re.findall(self.pattern_format, filename)
for segment in result:
parts = segment.replace("%", "").split(":")
key = parts[0]
if key == "seed" and "seed" in metadata_dict:
filename = filename.replace(segment, str(metadata_dict.get("seed", "")))
elif key == "width" and "size" in metadata_dict:
size = metadata_dict.get("size", "x")
w = size.split("x")[0] if isinstance(size, str) else size[0]
filename = filename.replace(segment, str(w))
elif key == "height" and "size" in metadata_dict:
size = metadata_dict.get("size", "x")
h = size.split("x")[1] if isinstance(size, str) else size[1]
filename = filename.replace(segment, str(h))
elif key == "pprompt" and "prompt" in metadata_dict:
p = metadata_dict.get("prompt", "").replace("\n", " ")
if len(parts) >= 2:
p = p[: int(parts[1])]
filename = filename.replace(segment, p.strip())
elif key == "nprompt" and "negative_prompt" in metadata_dict:
p = metadata_dict.get("negative_prompt", "").replace("\n", " ")
if len(parts) >= 2:
p = p[: int(parts[1])]
filename = filename.replace(segment, p.strip())
elif key == "model":
model_value = metadata_dict.get("checkpoint")
if isinstance(model_value, (bytes, os.PathLike)):
model_value = str(model_value)
if not isinstance(model_value, str) or not model_value:
model = "model_unavailable"
else:
model = os.path.splitext(os.path.basename(model_value))[0]
if len(parts) >= 2:
model = model[: int(parts[1])]
filename = filename.replace(segment, model)
elif key == "date":
from datetime import datetime
now = datetime.now()
date_table = {
"yyyy": f"{now.year:04d}", "yy": f"{now.year % 100:02d}",
"MM": f"{now.month:02d}", "dd": f"{now.day:02d}",
"hh": f"{now.hour:02d}", "mm": f"{now.minute:02d}",
"ss": f"{now.second:02d}",
}
if len(parts) >= 2:
date_format = parts[1]
else:
date_format = "yyyyMMddhhmmss"
for k, v in date_table.items():
date_format = date_format.replace(k, v)
filename = filename.replace(segment, date_format)
return filename
# ------------------------------------------------------------------
# Image saving
# ------------------------------------------------------------------
def save_images(self, images, filename_prefix, file_format, id, prompt=None,
extra_pnginfo=None, lossless_webp=True, quality=100,
embed_workflow=False, add_counter_to_filename=True):
results = []
# Try to get metadata from the original LoRA Manager's collector.
# The package directory name varies across installs, so we search
# sys.modules for any loaded module whose path ends with the
# expected submodule.
metadata_dict = {}
try:
get_metadata = None
MetadataProcessor = None
import sys
for mod_name, mod in sys.modules.items():
if mod is None:
continue
if mod_name.endswith(".py.metadata_collector") and hasattr(mod, "get_metadata"):
get_metadata = mod.get_metadata
if mod_name.endswith(".py.metadata_collector.metadata_processor") and hasattr(mod, "MetadataProcessor"):
MetadataProcessor = mod.MetadataProcessor
if get_metadata and MetadataProcessor:
break
if get_metadata and MetadataProcessor:
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
else:
logger.debug("[LM-Remote] metadata_collector not found in loaded modules")
except Exception:
logger.debug("[LM-Remote] metadata_collector not available, saving without generation metadata")
metadata = self.format_metadata(metadata_dict)
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
)
if not os.path.exists(full_output_folder):
os.makedirs(full_output_folder, exist_ok=True)
for i, image in enumerate(images):
img = 255.0 * image.cpu().numpy()
img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
base_filename = filename
if add_counter_to_filename:
current_counter = counter + i
base_filename += f"_{current_counter:05}_"
if file_format == "png":
file = base_filename + ".png"
save_kwargs = {"compress_level": self.compress_level}
pnginfo = PngImagePlugin.PngInfo()
elif file_format == "jpeg":
file = base_filename + ".jpg"
save_kwargs = {"quality": quality, "optimize": True}
elif file_format == "webp":
file = base_filename + ".webp"
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
file_path = os.path.join(full_output_folder, file)
try:
if file_format == "png":
if metadata:
pnginfo.add_text("parameters", metadata)
if embed_workflow and extra_pnginfo is not None:
pnginfo.add_text("workflow", json.dumps(extra_pnginfo["workflow"]))
save_kwargs["pnginfo"] = pnginfo
img.save(file_path, format="PNG", **save_kwargs)
elif file_format == "jpeg" and piexif:
if metadata:
try:
exif_dict = {"Exif": {piexif.ExifIFD.UserComment: b"UNICODE\0" + metadata.encode("utf-16be")}}
save_kwargs["exif"] = piexif.dump(exif_dict)
except Exception as e:
logger.error("Error adding EXIF data: %s", e)
img.save(file_path, format="JPEG", **save_kwargs)
elif file_format == "webp" and piexif:
try:
exif_dict = {}
if metadata:
exif_dict["Exif"] = {piexif.ExifIFD.UserComment: b"UNICODE\0" + metadata.encode("utf-16be")}
if embed_workflow and extra_pnginfo is not None:
exif_dict["0th"] = {piexif.ImageIFD.ImageDescription: "Workflow:" + json.dumps(extra_pnginfo["workflow"])}
save_kwargs["exif"] = piexif.dump(exif_dict)
except Exception as e:
logger.error("Error adding EXIF data: %s", e)
img.save(file_path, format="WEBP", **save_kwargs)
else:
img.save(file_path)
results.append({"filename": file, "subfolder": subfolder, "type": self.type})
except Exception as e:
logger.error("Error saving image: %s", e)
return results
def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png",
prompt=None, extra_pnginfo=None, lossless_webp=True, quality=100,
embed_workflow=False, add_counter_to_filename=True):
os.makedirs(self.output_dir, exist_ok=True)
if isinstance(images, (list, np.ndarray)):
pass
else:
if len(images.shape) == 3:
images = [images]
else:
images = [img for img in images]
self.save_images(
images, filename_prefix, file_format, id, prompt, extra_pnginfo,
lossless_webp, quality, embed_workflow, add_counter_to_filename,
)
return (images,)

141
nodes/utils.py Normal file
View File

@@ -0,0 +1,141 @@
"""Minimal utility classes/functions copied from the original LoRA Manager.
Only the pieces needed by the remote node classes are included here so that
ComfyUI-LM-Remote can function independently of the original package's Python
internals (while still requiring its JS widget files).
"""
from __future__ import annotations
import copy
import logging
import os
import sys
import folder_paths # type: ignore
logger = logging.getLogger(__name__)
class AnyType(str):
"""A special class that is always equal in not-equal comparisons.
Credit to pythongosssss.
"""
def __ne__(self, __value: object) -> bool:
return False
class FlexibleOptionalInputType(dict):
"""Allow flexible/dynamic input types on ComfyUI nodes.
Credit to Regis Gaughan, III (rgthree).
"""
def __init__(self, type):
self.type = type
def __getitem__(self, key):
return (self.type,)
def __contains__(self, key):
return True
any_type = AnyType("*")
def extract_lora_name(lora_path: str) -> str:
"""``'IL\\\\aorunIllstrious.safetensors'`` -> ``'aorunIllstrious'``"""
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def get_loras_list(kwargs: dict) -> list:
"""Extract loras list from either old or new kwargs format."""
if "loras" not in kwargs:
return []
loras_data = kwargs["loras"]
if isinstance(loras_data, dict) and "__value__" in loras_data:
return loras_data["__value__"]
elif isinstance(loras_data, list):
return loras_data
else:
logger.warning("Unexpected loras format: %s", type(loras_data))
return []
# ---------------------------------------------------------------------------
# Nunchaku LoRA helpers (copied verbatim from original)
# ---------------------------------------------------------------------------
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
import safetensors.torch
state_dict = {}
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
for k in f.keys():
if filter_prefix and not k.startswith(filter_prefix):
continue
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
return state_dict
def to_diffusers(input_lora):
import torch
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from diffusers.loaders import FluxLoraLoaderMixin
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else:
tensors = {k: v for k, v in input_lora.items()}
for k, v in tensors.items():
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
tensors[k] = v.to(torch.bfloat16)
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
return new_tensors
def nunchaku_load_lora(model, lora_name, lora_strength):
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
if not lora_path or not os.path.isfile(lora_path):
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
return model
model_wrapper = model.model.diffusion_model
module_name = model_wrapper.__class__.__module__
module = sys.modules.get(module_name)
copy_with_ctx = getattr(module, "copy_with_ctx", None)
if copy_with_ctx is not None:
ret_model_wrapper, ret_model = copy_with_ctx(model_wrapper)
ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
else:
logger.warning(
"Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. "
"Falling back to legacy loading logic."
)
transformer = model_wrapper.model
model_wrapper.model = None
ret_model = copy.deepcopy(model)
ret_model_wrapper = ret_model.model.diffusion_model
model_wrapper.model = transformer
ret_model_wrapper.model = transformer
ret_model_wrapper.loras.append((lora_path, lora_strength))
sd = to_diffusers(lora_path)
if "transformer.x_embedder.lora_A.weight" in sd:
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
assert new_in_channels % 4 == 0
new_in_channels = new_in_channels // 4
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
if old_in_channels < new_in_channels:
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
return ret_model

202
nodes/wanvideo.py Normal file
View File

@@ -0,0 +1,202 @@
"""Remote WanVideo LoRA nodes — fetch metadata from the remote LoRA Manager."""
from __future__ import annotations
import logging
import folder_paths # type: ignore
from .remote_utils import get_lora_info_remote
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
logger = logging.getLogger(__name__)
class WanVideoLoraSelectRemoteLM:
NAME = "WanVideo Lora Select (Remote, LoraManager)"
CATEGORY = "Lora Manager/stackers"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"low_mem_load": ("BOOLEAN", {
"default": False,
"tooltip": "Load LORA models with less VRAM usage, slower loading.",
}),
"merge_loras": ("BOOLEAN", {
"default": True,
"tooltip": "Merge LoRAs into the model.",
}),
"text": ("AUTOCOMPLETE_TEXT_LORAS", {
"placeholder": "Search LoRAs to add...",
"tooltip": "Format: <lora:lora_name:strength>",
}),
},
"optional": FlexibleOptionalInputType(any_type),
}
RETURN_TYPES = ("WANVIDLORA", "STRING", "STRING")
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
FUNCTION = "process_loras"
def process_loras(self, text, low_mem_load=False, merge_loras=True, **kwargs):
loras_list = []
all_trigger_words = []
active_loras = []
prev_lora = kwargs.get("prev_lora", None)
if prev_lora is not None:
loras_list.extend(prev_lora)
if not merge_loras:
low_mem_load = False
blocks = kwargs.get("blocks", {})
selected_blocks = blocks.get("selected_blocks", {})
layer_filter = blocks.get("layer_filter", "")
loras_from_widget = get_loras_list(kwargs)
for lora in loras_from_widget:
if not lora.get("active", False):
continue
lora_name = lora["name"]
model_strength = float(lora["strength"])
clip_strength = float(lora.get("clipStrength", model_strength))
lora_path, trigger_words = get_lora_info_remote(lora_name)
lora_item = {
"path": folder_paths.get_full_path("loras", lora_path),
"strength": model_strength,
"name": lora_path.split(".")[0],
"blocks": selected_blocks,
"layer_filter": layer_filter,
"low_mem_load": low_mem_load,
"merge_loras": merge_loras,
}
loras_list.append(lora_item)
active_loras.append((lora_name, model_strength, clip_strength))
all_trigger_words.extend(trigger_words)
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
formatted_loras = []
for name, ms, cs in active_loras:
if abs(ms - cs) > 0.001:
formatted_loras.append(f"<lora:{name}:{str(ms).strip()}:{str(cs).strip()}>")
else:
formatted_loras.append(f"<lora:{name}:{str(ms).strip()}>")
active_loras_text = " ".join(formatted_loras)
return (loras_list, trigger_words_text, active_loras_text)
class WanVideoLoraTextSelectRemoteLM:
NAME = "WanVideo Lora Select From Text (Remote, LoraManager)"
CATEGORY = "Lora Manager/stackers"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"low_mem_load": ("BOOLEAN", {
"default": False,
"tooltip": "Load LORA models with less VRAM usage, slower loading.",
}),
"merge_lora": ("BOOLEAN", {
"default": True,
"tooltip": "Merge LoRAs into the model.",
}),
"lora_syntax": ("STRING", {
"multiline": True,
"forceInput": True,
"tooltip": "Connect a TEXT output for LoRA syntax: <lora:name:strength>",
}),
},
"optional": {
"prev_lora": ("WANVIDLORA",),
"blocks": ("BLOCKS",),
},
}
RETURN_TYPES = ("WANVIDLORA", "STRING", "STRING")
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
FUNCTION = "process_loras_from_syntax"
def process_loras_from_syntax(self, lora_syntax, low_mem_load=False, merge_lora=True, **kwargs):
blocks = kwargs.get("blocks", {})
selected_blocks = blocks.get("selected_blocks", {})
layer_filter = blocks.get("layer_filter", "")
loras_list = []
all_trigger_words = []
active_loras = []
prev_lora = kwargs.get("prev_lora", None)
if prev_lora is not None:
loras_list.extend(prev_lora)
if not merge_lora:
low_mem_load = False
parts = lora_syntax.split("<lora:")
for part in parts[1:]:
end_index = part.find(">")
if end_index == -1:
continue
content = part[:end_index]
lora_parts = content.split(":")
lora_name_raw = ""
model_strength = 1.0
clip_strength = 1.0
if len(lora_parts) == 2:
lora_name_raw = lora_parts[0].strip()
try:
model_strength = float(lora_parts[1])
clip_strength = model_strength
except (ValueError, IndexError):
logger.warning("Invalid strength for LoRA '%s'. Skipping.", lora_name_raw)
continue
elif len(lora_parts) >= 3:
lora_name_raw = lora_parts[0].strip()
try:
model_strength = float(lora_parts[1])
clip_strength = float(lora_parts[2])
except (ValueError, IndexError):
logger.warning("Invalid strengths for LoRA '%s'. Skipping.", lora_name_raw)
continue
else:
continue
lora_path, trigger_words = get_lora_info_remote(lora_name_raw)
lora_item = {
"path": folder_paths.get_full_path("loras", lora_path),
"strength": model_strength,
"name": lora_path.split(".")[0],
"blocks": selected_blocks,
"layer_filter": layer_filter,
"low_mem_load": low_mem_load,
"merge_loras": merge_lora,
}
loras_list.append(lora_item)
active_loras.append((lora_name_raw, model_strength, clip_strength))
all_trigger_words.extend(trigger_words)
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
formatted_loras = []
for name, ms, cs in active_loras:
if abs(ms - cs) > 0.001:
formatted_loras.append(f"<lora:{name}:{str(ms).strip()}:{str(cs).strip()}>")
else:
formatted_loras.append(f"<lora:{name}:{str(ms).strip()}>")
active_loras_text = " ".join(formatted_loras)
return (loras_list, trigger_words_text, active_loras_text)

243
proxy.py Normal file
View File

@@ -0,0 +1,243 @@
"""
Reverse-proxy middleware that forwards LoRA Manager requests to the remote instance.
Registered as an aiohttp middleware on PromptServer.instance.app. It intercepts
requests matching known LoRA Manager URL prefixes and proxies them to the remote
Docker instance. Non-matching requests fall through to the regular ComfyUI router.
Routes that use ``PromptServer.instance.send_sync()`` are explicitly excluded
from proxying so the local original LoRA Manager handler can broadcast events
to the local ComfyUI frontend.
"""
from __future__ import annotations
import asyncio
import logging
import aiohttp
from aiohttp import web, WSMsgType
from .config import remote_config
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# URL prefixes that should be forwarded to the remote LoRA Manager
# ---------------------------------------------------------------------------
_PROXY_PREFIXES = (
"/api/lm/",
"/loras_static/",
"/locales/",
"/example_images_static/",
)
# Page routes served by the standalone LoRA Manager web UI
_PROXY_PAGE_ROUTES = {
"/loras",
"/checkpoints",
"/embeddings",
"/loras/recipes",
"/statistics",
}
# WebSocket endpoints to proxy
_WS_ROUTES = {
"/ws/fetch-progress",
"/ws/download-progress",
"/ws/init-progress",
}
# Routes that call send_sync on the remote side — these are NOT proxied.
# Instead they fall through to the local original LoRA Manager handler,
# which broadcasts events to the local ComfyUI frontend. The remote
# handler would broadcast to its own (empty) frontend, which is useless.
#
# These routes:
# /api/lm/loras/get_trigger_words -> trigger_word_update event
# /api/lm/update-lora-code -> lora_code_update event
# /api/lm/update-node-widget -> lm_widget_update event
# /api/lm/register-nodes -> lora_registry_refresh event
_SEND_SYNC_SKIP_ROUTES = {
"/api/lm/loras/get_trigger_words",
"/api/lm/update-lora-code",
"/api/lm/update-node-widget",
"/api/lm/register-nodes",
}
# Shared HTTP session for proxied requests (connection pooling)
_proxy_session: aiohttp.ClientSession | None = None
async def _get_proxy_session() -> aiohttp.ClientSession:
"""Return a shared aiohttp session for HTTP proxy requests."""
global _proxy_session
if _proxy_session is None or _proxy_session.closed:
timeout = aiohttp.ClientTimeout(total=remote_config.timeout)
_proxy_session = aiohttp.ClientSession(timeout=timeout)
return _proxy_session
def _should_proxy(path: str) -> bool:
"""Return True if *path* should be proxied to the remote instance."""
if any(path.startswith(p) for p in _PROXY_PREFIXES):
return True
if path in _PROXY_PAGE_ROUTES or path.rstrip("/") in _PROXY_PAGE_ROUTES:
return True
return False
def _is_ws_route(path: str) -> bool:
return path in _WS_ROUTES
async def _proxy_ws(request: web.Request) -> web.WebSocketResponse:
"""Proxy a WebSocket connection to the remote LoRA Manager."""
remote_url = remote_config.remote_url.replace("http://", "ws://").replace("https://", "wss://")
remote_ws_url = f"{remote_url}{request.path}"
if request.query_string:
remote_ws_url += f"?{request.query_string}"
local_ws = web.WebSocketResponse()
await local_ws.prepare(request)
timeout = aiohttp.ClientTimeout(total=None)
session = aiohttp.ClientSession(timeout=timeout)
try:
async with session.ws_connect(remote_ws_url) as remote_ws:
async def forward_local_to_remote():
async for msg in local_ws:
if msg.type == WSMsgType.TEXT:
await remote_ws.send_str(msg.data)
elif msg.type == WSMsgType.BINARY:
await remote_ws.send_bytes(msg.data)
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
return
async def forward_remote_to_local():
async for msg in remote_ws:
if msg.type == WSMsgType.TEXT:
await local_ws.send_str(msg.data)
elif msg.type == WSMsgType.BINARY:
await local_ws.send_bytes(msg.data)
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
return
# Run both directions concurrently. When either side closes,
# cancel the other to prevent hanging.
task_l2r = asyncio.create_task(forward_local_to_remote())
task_r2l = asyncio.create_task(forward_remote_to_local())
try:
done, pending = await asyncio.wait(
{task_l2r, task_r2l}, return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
finally:
# Ensure both sides are closed
if not remote_ws.closed:
await remote_ws.close()
if not local_ws.closed:
await local_ws.close()
except Exception as exc:
logger.warning("[LM-Remote] WebSocket proxy error for %s: %s", request.path, exc)
finally:
await session.close()
return local_ws
async def _proxy_http(request: web.Request) -> web.Response:
"""Forward an HTTP request to the remote LoRA Manager and return its response."""
remote_url = f"{remote_config.remote_url}{request.path}"
if request.query_string:
remote_url += f"?{request.query_string}"
# Read the request body (if any)
body = await request.read() if request.can_read_body else None
# Filter hop-by-hop headers
headers = {}
skip = {"host", "transfer-encoding", "connection", "keep-alive", "upgrade"}
for k, v in request.headers.items():
if k.lower() not in skip:
headers[k] = v
session = await _get_proxy_session()
try:
async with session.request(
method=request.method,
url=remote_url,
headers=headers,
data=body,
) as resp:
resp_body = await resp.read()
resp_headers = {}
for k, v in resp.headers.items():
if k.lower() not in ("transfer-encoding", "content-encoding", "content-length"):
resp_headers[k] = v
return web.Response(
status=resp.status,
body=resp_body,
headers=resp_headers,
)
except Exception as exc:
logger.error("[LM-Remote] Proxy error for %s %s: %s", request.method, request.path, exc)
return web.json_response(
{"error": f"Remote LoRA Manager unavailable: {exc}"},
status=502,
)
# ---------------------------------------------------------------------------
# Middleware factory
# ---------------------------------------------------------------------------
@web.middleware
async def lm_remote_proxy_middleware(request: web.Request, handler):
"""aiohttp middleware that intercepts LoRA Manager requests."""
if not remote_config.is_configured:
return await handler(request)
path = request.path
# Routes that use send_sync must NOT be proxied — let the local
# original LoRA Manager handle them so events reach the local browser.
if path in _SEND_SYNC_SKIP_ROUTES:
return await handler(request)
# WebSocket routes
if _is_ws_route(path):
return await _proxy_ws(request)
# Regular proxy routes
if _should_proxy(path):
return await _proxy_http(request)
# Not a LoRA Manager route — fall through
return await handler(request)
async def _cleanup_proxy_session(app) -> None:
"""Shutdown hook to close the shared proxy session."""
global _proxy_session
if _proxy_session and not _proxy_session.closed:
await _proxy_session.close()
_proxy_session = None
def register_proxy(app) -> None:
"""Insert the proxy middleware into the aiohttp app."""
if not remote_config.is_configured:
logger.warning("[LM-Remote] No remote_url configured — proxy disabled")
return
# Insert at position 0 so we run before the original LoRA Manager routes
app.middlewares.insert(0, lm_remote_proxy_middleware)
app.on_shutdown.append(_cleanup_proxy_session)
logger.info("[LM-Remote] Proxy routes registered -> %s", remote_config.remote_url)

214
remote_client.py Normal file
View File

@@ -0,0 +1,214 @@
"""HTTP client for the remote LoRA Manager instance."""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Any
import aiohttp
from .config import remote_config
logger = logging.getLogger(__name__)
# Cache TTL in seconds — how long before we re-fetch the full LoRA list
_CACHE_TTL = 60
class RemoteLoraClient:
"""Singleton HTTP client that talks to the remote LoRA Manager.
Uses the actual LoRA Manager REST API endpoints:
- ``GET /api/lm/loras/list?page_size=9999`` — paginated LoRA list
- ``GET /api/lm/loras/get-trigger-words?name=X`` — trigger words
- ``POST /api/lm/loras/random-sample`` — random LoRA selection
- ``POST /api/lm/loras/cycler-list`` — sorted LoRA list for cycler
A short-lived in-memory cache avoids redundant calls to the list endpoint
during a single workflow execution (which may resolve many LoRAs at once).
"""
_instance: RemoteLoraClient | None = None
_session: aiohttp.ClientSession | None = None
def __init__(self):
self._lora_cache: list[dict] = []
self._lora_cache_ts: float = 0
self._checkpoint_cache: list[dict] = []
self._checkpoint_cache_ts: float = 0
@classmethod
def get_instance(cls) -> RemoteLoraClient:
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=remote_config.timeout)
self._session = aiohttp.ClientSession(timeout=timeout)
return self._session
async def close(self):
if self._session and not self._session.closed:
await self._session.close()
self._session = None
# ------------------------------------------------------------------
# Core HTTP helpers
# ------------------------------------------------------------------
async def _get_json(self, path: str, params: dict | None = None) -> Any:
url = f"{remote_config.remote_url}{path}"
session = await self._get_session()
async with session.get(url, params=params) as resp:
resp.raise_for_status()
return await resp.json()
async def _post_json(self, path: str, json_body: dict | None = None) -> Any:
url = f"{remote_config.remote_url}{path}"
session = await self._get_session()
async with session.post(url, json=json_body) as resp:
resp.raise_for_status()
return await resp.json()
# ------------------------------------------------------------------
# Cached list helpers
# ------------------------------------------------------------------
async def _get_lora_list_cached(self) -> list[dict]:
"""Return the full LoRA list, using a short-lived cache."""
now = time.monotonic()
if self._lora_cache and (now - self._lora_cache_ts) < _CACHE_TTL:
return self._lora_cache
try:
data = await self._get_json(
"/api/lm/loras/list", params={"page_size": "9999"}
)
self._lora_cache = data.get("items", [])
self._lora_cache_ts = now
except Exception as exc:
logger.warning("[LM-Remote] Failed to fetch LoRA list: %s", exc)
# Return stale cache on error, or empty list
return self._lora_cache
async def _get_checkpoint_list_cached(self) -> list[dict]:
"""Return the full checkpoint list, using a short-lived cache."""
now = time.monotonic()
if self._checkpoint_cache and (now - self._checkpoint_cache_ts) < _CACHE_TTL:
return self._checkpoint_cache
try:
data = await self._get_json(
"/api/lm/checkpoints/list", params={"page_size": "9999"}
)
self._checkpoint_cache = data.get("items", [])
self._checkpoint_cache_ts = now
except Exception as exc:
logger.warning("[LM-Remote] Failed to fetch checkpoint list: %s", exc)
return self._checkpoint_cache
def _find_item_by_name(self, items: list[dict], name: str) -> dict | None:
"""Find an item in a list by file_name."""
for item in items:
if item.get("file_name") == name:
return item
return None
# ------------------------------------------------------------------
# LoRA metadata
# ------------------------------------------------------------------
async def get_lora_info(self, lora_name: str) -> tuple[str, list[str]]:
"""Return (relative_path, trigger_words) for a LoRA by display name.
Uses the cached ``/api/lm/loras/list`` data. Falls back to the
per-LoRA ``get-trigger-words`` endpoint if the list lookup fails.
"""
import posixpath
try:
items = await self._get_lora_list_cached()
item = self._find_item_by_name(items, lora_name)
if item:
file_path = item.get("file_path", "")
file_path = remote_config.map_path(file_path)
# file_path is the absolute path (forward-slashed) from
# the remote. We need a relative path that the local
# folder_paths.get_full_path("loras", ...) can resolve.
#
# The ``folder`` field gives the subfolder within the
# model root (e.g. "anime" or "anime/characters").
# The basename of file_path has the extension.
#
# Example: file_path="/mnt/loras/anime/test.safetensors"
# folder="anime"
# -> basename="test.safetensors"
# -> relative="anime/test.safetensors"
folder = item.get("folder", "")
basename = posixpath.basename(file_path) # "test.safetensors"
if folder:
relative = f"{folder}/{basename}"
else:
relative = basename
civitai = item.get("civitai") or {}
trigger_words = civitai.get("trainedWords", []) if civitai else []
return relative, trigger_words
# Fallback: try the specific trigger-words endpoint
tw_data = await self._get_json(
"/api/lm/loras/get-trigger-words",
params={"name": lora_name},
)
trigger_words = tw_data.get("trigger_words", [])
return lora_name, trigger_words
except Exception as exc:
logger.warning("[LM-Remote] get_lora_info(%s) failed: %s", lora_name, exc)
return lora_name, []
async def get_lora_hash(self, lora_name: str) -> str | None:
"""Return the SHA-256 hash for a LoRA by display name."""
try:
items = await self._get_lora_list_cached()
item = self._find_item_by_name(items, lora_name)
if item:
return item.get("sha256") or item.get("hash")
except Exception as exc:
logger.warning("[LM-Remote] get_lora_hash(%s) failed: %s", lora_name, exc)
return None
async def get_checkpoint_hash(self, checkpoint_name: str) -> str | None:
"""Return the SHA-256 hash for a checkpoint by display name."""
try:
items = await self._get_checkpoint_list_cached()
item = self._find_item_by_name(items, checkpoint_name)
if item:
return item.get("sha256") or item.get("hash")
except Exception as exc:
logger.warning("[LM-Remote] get_checkpoint_hash(%s) failed: %s", checkpoint_name, exc)
return None
async def get_random_loras(self, **kwargs) -> list[dict]:
"""Ask the remote to generate random LoRAs (for Randomizer node)."""
try:
result = await self._post_json("/api/lm/loras/random-sample", json_body=kwargs)
return result if isinstance(result, list) else result.get("loras", [])
except Exception as exc:
logger.warning("[LM-Remote] get_random_loras failed: %s", exc)
return []
async def get_cycler_list(self, **kwargs) -> list[dict]:
"""Ask the remote for a sorted LoRA list (for Cycler node)."""
try:
result = await self._post_json("/api/lm/loras/cycler-list", json_body=kwargs)
return result if isinstance(result, list) else result.get("loras", [])
except Exception as exc:
logger.warning("[LM-Remote] get_cycler_list failed: %s", exc)
return []

View File

@@ -0,0 +1,166 @@
/**
* JS shim for Remote Lora Loader / Remote Lora Text Loader nodes.
*
* Re-uses all widget infrastructure from the original ComfyUI-Lora-Manager;
* the only difference is matching on the remote node NAMEs.
*/
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import {
collectActiveLorasFromChain,
updateConnectedTriggerWords,
chainCallback,
mergeLoras,
getAllGraphNodes,
getNodeFromGraph,
} from "/extensions/ComfyUI-Lora-Manager/utils.js";
import { addLorasWidget } from "/extensions/ComfyUI-Lora-Manager/loras_widget.js";
import { applyLoraValuesToText, debounce } from "/extensions/ComfyUI-Lora-Manager/lora_syntax_utils.js";
import { applySelectionHighlight } from "/extensions/ComfyUI-Lora-Manager/trigger_word_highlight.js";
app.registerExtension({
name: "LoraManager.LoraLoaderRemote",
setup() {
api.addEventListener("lora_code_update", (event) => {
this.handleLoraCodeUpdate(event.detail || {});
});
},
handleLoraCodeUpdate(message) {
const nodeId = message?.node_id ?? message?.id;
const graphId = message?.graph_id;
const loraCode = message?.lora_code ?? "";
const mode = message?.mode ?? "append";
const numericNodeId = typeof nodeId === "string" ? Number(nodeId) : nodeId;
if (numericNodeId === -1) {
const loraLoaderNodes = getAllGraphNodes(app.graph)
.map(({ node }) => node)
.filter((node) => node?.comfyClass === "Lora Loader (Remote, LoraManager)");
if (loraLoaderNodes.length > 0) {
loraLoaderNodes.forEach((node) => {
this.updateNodeLoraCode(node, loraCode, mode);
});
}
return;
}
const node = getNodeFromGraph(graphId, numericNodeId);
if (
!node ||
(node.comfyClass !== "Lora Loader (Remote, LoraManager)" &&
node.comfyClass !== "Lora Stacker (Remote, LoraManager)" &&
node.comfyClass !== "WanVideo Lora Select (Remote, LoraManager)")
) {
return;
}
this.updateNodeLoraCode(node, loraCode, mode);
},
updateNodeLoraCode(node, loraCode, mode) {
const inputWidget = node.inputWidget;
if (!inputWidget) return;
const currentValue = inputWidget.value || "";
if (mode === "replace") {
inputWidget.value = loraCode;
} else {
inputWidget.value = currentValue.trim()
? `${currentValue.trim()} ${loraCode}`
: loraCode;
}
if (typeof inputWidget.callback === "function") {
inputWidget.callback(inputWidget.value);
}
},
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeType.comfyClass === "Lora Loader (Remote, LoraManager)") {
chainCallback(nodeType.prototype, "onNodeCreated", function () {
this.serialize_widgets = true;
this.addInput("clip", "CLIP", { shape: 7 });
this.addInput("lora_stack", "LORA_STACK", { shape: 7 });
let isUpdating = false;
let isSyncingInput = false;
const self = this;
let _mode = this.mode;
Object.defineProperty(this, "mode", {
get() { return _mode; },
set(value) {
const oldValue = _mode;
_mode = value;
if (self.onModeChange) self.onModeChange(value, oldValue);
},
});
this.onModeChange = function (newMode) {
const allActiveLoraNames = collectActiveLorasFromChain(self);
updateConnectedTriggerWords(self, allActiveLoraNames);
};
const inputWidget = this.widgets[0];
this.inputWidget = inputWidget;
const scheduleInputSync = debounce((lorasValue) => {
if (isSyncingInput) return;
isSyncingInput = true;
isUpdating = true;
try {
const nextText = applyLoraValuesToText(inputWidget.value, lorasValue);
if (inputWidget.value !== nextText) inputWidget.value = nextText;
} finally {
isUpdating = false;
isSyncingInput = false;
}
});
this.lorasWidget = addLorasWidget(
this, "loras",
{ onSelectionChange: (selection) => applySelectionHighlight(this, selection) },
(value) => {
if (isUpdating) return;
isUpdating = true;
try {
const allActiveLoraNames = collectActiveLorasFromChain(this);
updateConnectedTriggerWords(this, allActiveLoraNames);
} finally {
isUpdating = false;
}
scheduleInputSync(value);
}
).widget;
inputWidget.callback = (value) => {
if (isUpdating) return;
isUpdating = true;
try {
const currentLoras = this.lorasWidget.value || [];
const mergedLoras = mergeLoras(value, currentLoras);
this.lorasWidget.value = mergedLoras;
const allActiveLoraNames = collectActiveLorasFromChain(this);
updateConnectedTriggerWords(this, allActiveLoraNames);
} finally {
isUpdating = false;
}
};
});
}
},
async loadedGraphNode(node) {
if (node.comfyClass === "Lora Loader (Remote, LoraManager)") {
let existingLoras = [];
if (node.widgets_values && node.widgets_values.length > 0) {
existingLoras = node.widgets_values[1] || [];
}
const mergedLoras = mergeLoras(node.widgets[0].value, existingLoras);
node.lorasWidget.value = mergedLoras;
}
},
});

View File

@@ -0,0 +1,97 @@
/**
* JS shim for Remote Lora Stacker node.
*/
import { app } from "../../scripts/app.js";
import {
getActiveLorasFromNode,
updateConnectedTriggerWords,
updateDownstreamLoaders,
chainCallback,
mergeLoras,
} from "/extensions/ComfyUI-Lora-Manager/utils.js";
import { addLorasWidget } from "/extensions/ComfyUI-Lora-Manager/loras_widget.js";
import { applyLoraValuesToText, debounce } from "/extensions/ComfyUI-Lora-Manager/lora_syntax_utils.js";
import { applySelectionHighlight } from "/extensions/ComfyUI-Lora-Manager/trigger_word_highlight.js";
app.registerExtension({
name: "LoraManager.LoraStackerRemote",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeType.comfyClass === "Lora Stacker (Remote, LoraManager)") {
chainCallback(nodeType.prototype, "onNodeCreated", async function () {
this.serialize_widgets = true;
this.addInput("lora_stack", "LORA_STACK", { shape: 7 });
let isUpdating = false;
let isSyncingInput = false;
const inputWidget = this.widgets[0];
this.inputWidget = inputWidget;
const scheduleInputSync = debounce((lorasValue) => {
if (isSyncingInput) return;
isSyncingInput = true;
isUpdating = true;
try {
const nextText = applyLoraValuesToText(inputWidget.value, lorasValue);
if (inputWidget.value !== nextText) inputWidget.value = nextText;
} finally {
isUpdating = false;
isSyncingInput = false;
}
});
const result = addLorasWidget(
this, "loras",
{ onSelectionChange: (selection) => applySelectionHighlight(this, selection) },
(value) => {
if (isUpdating) return;
isUpdating = true;
try {
const isNodeActive = this.mode === undefined || this.mode === 0 || this.mode === 3;
const activeLoraNames = new Set();
if (isNodeActive) {
value.forEach((lora) => { if (lora.active) activeLoraNames.add(lora.name); });
}
updateConnectedTriggerWords(this, activeLoraNames);
updateDownstreamLoaders(this);
} finally {
isUpdating = false;
}
scheduleInputSync(value);
}
);
this.lorasWidget = result.widget;
inputWidget.callback = (value) => {
if (isUpdating) return;
isUpdating = true;
try {
const currentLoras = this.lorasWidget?.value || [];
const mergedLoras = mergeLoras(value, currentLoras);
if (this.lorasWidget) this.lorasWidget.value = mergedLoras;
const isNodeActive = this.mode === undefined || this.mode === 0 || this.mode === 3;
const activeLoraNames = isNodeActive ? getActiveLorasFromNode(this) : new Set();
updateConnectedTriggerWords(this, activeLoraNames);
updateDownstreamLoaders(this);
} finally {
isUpdating = false;
}
};
});
}
},
async loadedGraphNode(node) {
if (node.comfyClass === "Lora Stacker (Remote, LoraManager)") {
let existingLoras = [];
if (node.widgets_values && node.widgets_values.length > 0) {
existingLoras = node.widgets_values[1] || [];
}
const inputWidget = node.inputWidget || node.widgets[0];
const mergedLoras = mergeLoras(inputWidget.value, existingLoras);
node.lorasWidget.value = mergedLoras;
}
},
});

View File

@@ -0,0 +1,89 @@
/**
* JS shim for Remote WanVideo Lora Select node.
*/
import { app } from "../../scripts/app.js";
import {
getActiveLorasFromNode,
updateConnectedTriggerWords,
chainCallback,
mergeLoras,
} from "/extensions/ComfyUI-Lora-Manager/utils.js";
import { addLorasWidget } from "/extensions/ComfyUI-Lora-Manager/loras_widget.js";
import { applyLoraValuesToText, debounce } from "/extensions/ComfyUI-Lora-Manager/lora_syntax_utils.js";
app.registerExtension({
name: "LoraManager.WanVideoLoraSelectRemote",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeType.comfyClass === "WanVideo Lora Select (Remote, LoraManager)") {
chainCallback(nodeType.prototype, "onNodeCreated", async function () {
this.serialize_widgets = true;
this.addInput("prev_lora", "WANVIDLORA", { shape: 7 });
this.addInput("blocks", "SELECTEDBLOCKS", { shape: 7 });
let isUpdating = false;
let isSyncingInput = false;
// text widget is at index 2 (after low_mem_load, merge_loras)
const inputWidget = this.widgets[2];
this.inputWidget = inputWidget;
const scheduleInputSync = debounce((lorasValue) => {
if (isSyncingInput) return;
isSyncingInput = true;
isUpdating = true;
try {
const nextText = applyLoraValuesToText(inputWidget.value, lorasValue);
if (inputWidget.value !== nextText) inputWidget.value = nextText;
} finally {
isUpdating = false;
isSyncingInput = false;
}
});
const result = addLorasWidget(this, "loras", {}, (value) => {
if (isUpdating) return;
isUpdating = true;
try {
const activeLoraNames = new Set();
value.forEach((lora) => { if (lora.active) activeLoraNames.add(lora.name); });
updateConnectedTriggerWords(this, activeLoraNames);
} finally {
isUpdating = false;
}
scheduleInputSync(value);
});
this.lorasWidget = result.widget;
inputWidget.callback = (value) => {
if (isUpdating) return;
isUpdating = true;
try {
const currentLoras = this.lorasWidget?.value || [];
const mergedLoras = mergeLoras(value, currentLoras);
if (this.lorasWidget) this.lorasWidget.value = mergedLoras;
const activeLoraNames = getActiveLorasFromNode(this);
updateConnectedTriggerWords(this, activeLoraNames);
} finally {
isUpdating = false;
}
};
});
}
},
async loadedGraphNode(node) {
if (node.comfyClass === "WanVideo Lora Select (Remote, LoraManager)") {
let existingLoras = [];
if (node.widgets_values && node.widgets_values.length > 0) {
// 0=low_mem_load, 1=merge_loras, 2=text, 3=loras
existingLoras = node.widgets_values[3] || [];
}
const inputWidget = node.inputWidget || node.widgets[2];
const mergedLoras = mergeLoras(inputWidget.value, existingLoras);
node.lorasWidget.value = mergedLoras;
}
},
});