From 980f406573863f1910a470a9052da16d31f81291 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 22 Feb 2026 00:46:03 +0100 Subject: [PATCH] feat: initial release of ComfyUI-LM-Remote Remote-aware LoRA Manager nodes that fetch metadata via HTTP from a remote Docker instance while loading LoRA files from local NFS/SMB mounts. Includes reverse-proxy middleware for transparent web UI access. Co-Authored-By: Claude Opus 4.6 --- __init__.py | 54 ++++ config.json | 5 + config.py | 62 +++++ nodes/__init__.py | 0 nodes/lora_cycler.py | 92 +++++++ nodes/lora_loader.py | 205 ++++++++++++++++ nodes/lora_pool.py | 55 +++++ nodes/lora_randomizer.py | 113 +++++++++ nodes/lora_stacker.py | 71 ++++++ nodes/remote_utils.py | 44 ++++ nodes/save_image.py | 381 +++++++++++++++++++++++++++++ nodes/utils.py | 141 +++++++++++ nodes/wanvideo.py | 202 +++++++++++++++ proxy.py | 243 ++++++++++++++++++ remote_client.py | 214 ++++++++++++++++ web/comfyui/lora_loader_remote.js | 166 +++++++++++++ web/comfyui/lora_stacker_remote.js | 97 ++++++++ web/comfyui/wanvideo_remote.js | 89 +++++++ 18 files changed, 2234 insertions(+) create mode 100644 __init__.py create mode 100644 config.json create mode 100644 config.py create mode 100644 nodes/__init__.py create mode 100644 nodes/lora_cycler.py create mode 100644 nodes/lora_loader.py create mode 100644 nodes/lora_pool.py create mode 100644 nodes/lora_randomizer.py create mode 100644 nodes/lora_stacker.py create mode 100644 nodes/remote_utils.py create mode 100644 nodes/save_image.py create mode 100644 nodes/utils.py create mode 100644 nodes/wanvideo.py create mode 100644 proxy.py create mode 100644 remote_client.py create mode 100644 web/comfyui/lora_loader_remote.js create mode 100644 web/comfyui/lora_stacker_remote.js create mode 100644 web/comfyui/wanvideo_remote.js diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..a0506a1 --- /dev/null +++ b/__init__.py @@ -0,0 +1,54 @@ +""" +ComfyUI-LM-Remote — Remote LoRA Manager integration for ComfyUI. + +Provides: +1. A reverse-proxy middleware that forwards all LoRA Manager API/UI/WS + requests to a remote Docker instance. +2. Remote-aware node classes that fetch metadata via HTTP instead of the + local ServiceRegistry, while still loading LoRA files from local + NFS/SMB-mounted paths. + +Requires the original ComfyUI-Lora-Manager package to be installed alongside +for its widget JS files and custom widget types. +""" +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +# ── Import node classes ──────────────────────────────────────────────── +from .nodes.lora_loader import LoraLoaderRemoteLM, LoraTextLoaderRemoteLM +from .nodes.lora_stacker import LoraStackerRemoteLM +from .nodes.lora_randomizer import LoraRandomizerRemoteLM +from .nodes.lora_cycler import LoraCyclerRemoteLM +from .nodes.lora_pool import LoraPoolRemoteLM +from .nodes.save_image import SaveImageRemoteLM +from .nodes.wanvideo import WanVideoLoraSelectRemoteLM, WanVideoLoraTextSelectRemoteLM + +# ── NODE_CLASS_MAPPINGS (how ComfyUI discovers nodes) ────────────────── +NODE_CLASS_MAPPINGS = { + LoraLoaderRemoteLM.NAME: LoraLoaderRemoteLM, + LoraTextLoaderRemoteLM.NAME: LoraTextLoaderRemoteLM, + LoraStackerRemoteLM.NAME: LoraStackerRemoteLM, + LoraRandomizerRemoteLM.NAME: LoraRandomizerRemoteLM, + LoraCyclerRemoteLM.NAME: LoraCyclerRemoteLM, + LoraPoolRemoteLM.NAME: LoraPoolRemoteLM, + SaveImageRemoteLM.NAME: SaveImageRemoteLM, + WanVideoLoraSelectRemoteLM.NAME: WanVideoLoraSelectRemoteLM, + WanVideoLoraTextSelectRemoteLM.NAME: WanVideoLoraTextSelectRemoteLM, +} + +# ── WEB_DIRECTORY tells ComfyUI where to find our JS extensions ─────── +WEB_DIRECTORY = "./web/comfyui" + +# ── Register proxy middleware ────────────────────────────────────────── +try: + from server import PromptServer # type: ignore + from .proxy import register_proxy + + register_proxy(PromptServer.instance.app) +except Exception as exc: + logger.warning("[LM-Remote] Could not register proxy middleware: %s", exc) + +__all__ = ["NODE_CLASS_MAPPINGS", "WEB_DIRECTORY"] diff --git a/config.json b/config.json new file mode 100644 index 0000000..e30d5ee --- /dev/null +++ b/config.json @@ -0,0 +1,5 @@ +{ + "remote_url": "http://192.168.1.100:8188", + "timeout": 30, + "path_mappings": {} +} diff --git a/config.py b/config.py new file mode 100644 index 0000000..a559d05 --- /dev/null +++ b/config.py @@ -0,0 +1,62 @@ +"""Configuration for ComfyUI-LM-Remote.""" +from __future__ import annotations + +import json +import logging +import os +from pathlib import Path + +logger = logging.getLogger(__name__) + +_PACKAGE_DIR = Path(__file__).resolve().parent +_CONFIG_FILE = _PACKAGE_DIR / "config.json" + + +class RemoteConfig: + """Holds remote LoRA Manager connection settings.""" + + def __init__(self): + self.remote_url: str = "" + self.timeout: int = 30 + self.path_mappings: dict[str, str] = {} + self._load() + + # ------------------------------------------------------------------ + def _load(self): + # Environment variable takes priority + env_url = os.environ.get("LM_REMOTE_URL", "") + env_timeout = os.environ.get("LM_REMOTE_TIMEOUT", "") + + # Load config.json defaults + if _CONFIG_FILE.exists(): + try: + with open(_CONFIG_FILE, "r", encoding="utf-8") as f: + data = json.load(f) + self.remote_url = data.get("remote_url", "") + self.timeout = int(data.get("timeout", 30)) + self.path_mappings = data.get("path_mappings", {}) + except Exception as exc: + logger.warning("[LM-Remote] Failed to read config.json: %s", exc) + + # Env overrides + if env_url: + self.remote_url = env_url + if env_timeout: + self.timeout = int(env_timeout) + + # Strip trailing slash + self.remote_url = self.remote_url.rstrip("/") + + @property + def is_configured(self) -> bool: + return bool(self.remote_url) + + def map_path(self, remote_path: str) -> str: + """Apply remote->local path prefix mappings.""" + for remote_prefix, local_prefix in self.path_mappings.items(): + if remote_path.startswith(remote_prefix): + return local_prefix + remote_path[len(remote_prefix):] + return remote_path + + +remote_config = RemoteConfig() diff --git a/nodes/__init__.py b/nodes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nodes/lora_cycler.py b/nodes/lora_cycler.py new file mode 100644 index 0000000..79e1af6 --- /dev/null +++ b/nodes/lora_cycler.py @@ -0,0 +1,92 @@ +"""Remote LoRA Cycler — uses remote API instead of local ServiceRegistry.""" +from __future__ import annotations + +import logging +import os + +from .remote_utils import get_lora_info_remote +from ..remote_client import RemoteLoraClient + +logger = logging.getLogger(__name__) + + +class LoraCyclerRemoteLM: + """Node that sequentially cycles through LoRAs from a pool (remote).""" + + NAME = "Lora Cycler (Remote, LoraManager)" + CATEGORY = "Lora Manager/randomizer" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "cycler_config": ("CYCLER_CONFIG", {}), + }, + "optional": { + "pool_config": ("POOL_CONFIG", {}), + }, + } + + RETURN_TYPES = ("LORA_STACK",) + RETURN_NAMES = ("LORA_STACK",) + FUNCTION = "cycle" + OUTPUT_NODE = False + + async def cycle(self, cycler_config, pool_config=None): + current_index = cycler_config.get("current_index", 1) + model_strength = float(cycler_config.get("model_strength", 1.0)) + clip_strength = float(cycler_config.get("clip_strength", 1.0)) + execution_index = cycler_config.get("execution_index") + + client = RemoteLoraClient.get_instance() + lora_list = await client.get_cycler_list( + pool_config=pool_config, sort_by="filename" + ) + + total_count = len(lora_list) + + if total_count == 0: + logger.warning("[LoraCyclerRemoteLM] No LoRAs available in pool") + return { + "result": ([],), + "ui": { + "current_index": [1], + "next_index": [1], + "total_count": [0], + "current_lora_name": [""], + "current_lora_filename": [""], + "error": ["No LoRAs available in pool"], + }, + } + + actual_index = execution_index if execution_index is not None else current_index + clamped_index = max(1, min(actual_index, total_count)) + + current_lora = lora_list[clamped_index - 1] + + lora_path, _ = get_lora_info_remote(current_lora["file_name"]) + if not lora_path: + logger.warning("[LoraCyclerRemoteLM] Could not find path for LoRA: %s", current_lora["file_name"]) + lora_stack = [] + else: + lora_path = lora_path.replace("/", os.sep) + lora_stack = [(lora_path, model_strength, clip_strength)] + + next_index = clamped_index + 1 + if next_index > total_count: + next_index = 1 + + next_lora = lora_list[next_index - 1] + + return { + "result": (lora_stack,), + "ui": { + "current_index": [clamped_index], + "next_index": [next_index], + "total_count": [total_count], + "current_lora_name": [current_lora["file_name"]], + "current_lora_filename": [current_lora["file_name"]], + "next_lora_name": [next_lora["file_name"]], + "next_lora_filename": [next_lora["file_name"]], + }, + } diff --git a/nodes/lora_loader.py b/nodes/lora_loader.py new file mode 100644 index 0000000..1696d9b --- /dev/null +++ b/nodes/lora_loader.py @@ -0,0 +1,205 @@ +"""Remote LoRA Loader nodes — fetch metadata from the remote LoRA Manager.""" +from __future__ import annotations + +import logging +import re + +from nodes import LoraLoader # type: ignore + +from .remote_utils import get_lora_info_remote +from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora + +logger = logging.getLogger(__name__) + + +class LoraLoaderRemoteLM: + NAME = "Lora Loader (Remote, LoraManager)" + CATEGORY = "Lora Manager/loaders" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("MODEL",), + "text": ("AUTOCOMPLETE_TEXT_LORAS", { + "placeholder": "Search LoRAs to add...", + "tooltip": "Format: separated by spaces or punctuation", + }), + }, + "optional": FlexibleOptionalInputType(any_type), + } + + RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING") + RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras") + FUNCTION = "load_loras" + + def load_loras(self, model, text, **kwargs): + loaded_loras = [] + all_trigger_words = [] + + clip = kwargs.get("clip", None) + lora_stack = kwargs.get("lora_stack", None) + + is_nunchaku_model = False + try: + model_wrapper = model.model.diffusion_model + if model_wrapper.__class__.__name__ == "ComfyFluxWrapper": + is_nunchaku_model = True + logger.info("Detected Nunchaku Flux model") + except (AttributeError, TypeError): + pass + + # Process lora_stack + if lora_stack: + for lora_path, model_strength, clip_strength in lora_stack: + if is_nunchaku_model: + model = nunchaku_load_lora(model, lora_path, model_strength) + else: + model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength) + + lora_name = extract_lora_name(lora_path) + _, trigger_words = get_lora_info_remote(lora_name) + all_trigger_words.extend(trigger_words) + if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001: + loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}") + else: + loaded_loras.append(f"{lora_name}: {model_strength}") + + # Process loras from widget + loras_list = get_loras_list(kwargs) + for lora in loras_list: + if not lora.get("active", False): + continue + + lora_name = lora["name"] + model_strength = float(lora["strength"]) + clip_strength = float(lora.get("clipStrength", model_strength)) + + lora_path, trigger_words = get_lora_info_remote(lora_name) + + if is_nunchaku_model: + model = nunchaku_load_lora(model, lora_path, model_strength) + else: + model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength) + + if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001: + loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}") + else: + loaded_loras.append(f"{lora_name}: {model_strength}") + + all_trigger_words.extend(trigger_words) + + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + formatted_loras = [] + for item in loaded_loras: + parts = item.split(":") + lora_name = parts[0] + strength_parts = parts[1].strip().split(",") + if len(strength_parts) > 1: + formatted_loras.append(f"") + else: + formatted_loras.append(f"") + + formatted_loras_text = " ".join(formatted_loras) + return (model, clip, trigger_words_text, formatted_loras_text) + + +class LoraTextLoaderRemoteLM: + NAME = "LoRA Text Loader (Remote, LoraManager)" + CATEGORY = "Lora Manager/loaders" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("MODEL",), + "lora_syntax": ("STRING", { + "forceInput": True, + "tooltip": "Format: separated by spaces or punctuation", + }), + }, + "optional": { + "clip": ("CLIP",), + "lora_stack": ("LORA_STACK",), + }, + } + + RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING") + RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras") + FUNCTION = "load_loras_from_text" + + def parse_lora_syntax(self, text): + pattern = r"]+):([^:>]+)(?::([^:>]+))?>" + matches = re.findall(pattern, text, re.IGNORECASE) + loras = [] + for match in matches: + loras.append({ + "name": match[0], + "model_strength": float(match[1]), + "clip_strength": float(match[2]) if match[2] else float(match[1]), + }) + return loras + + def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None): + loaded_loras = [] + all_trigger_words = [] + + is_nunchaku_model = False + try: + model_wrapper = model.model.diffusion_model + if model_wrapper.__class__.__name__ == "ComfyFluxWrapper": + is_nunchaku_model = True + logger.info("Detected Nunchaku Flux model") + except (AttributeError, TypeError): + pass + + if lora_stack: + for lora_path, model_strength, clip_strength in lora_stack: + if is_nunchaku_model: + model = nunchaku_load_lora(model, lora_path, model_strength) + else: + model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength) + + lora_name = extract_lora_name(lora_path) + _, trigger_words = get_lora_info_remote(lora_name) + all_trigger_words.extend(trigger_words) + if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001: + loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}") + else: + loaded_loras.append(f"{lora_name}: {model_strength}") + + parsed_loras = self.parse_lora_syntax(lora_syntax) + for lora in parsed_loras: + lora_name = lora["name"] + model_strength = lora["model_strength"] + clip_strength = lora["clip_strength"] + + lora_path, trigger_words = get_lora_info_remote(lora_name) + + if is_nunchaku_model: + model = nunchaku_load_lora(model, lora_path, model_strength) + else: + model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength) + + if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001: + loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}") + else: + loaded_loras.append(f"{lora_name}: {model_strength}") + + all_trigger_words.extend(trigger_words) + + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + formatted_loras = [] + for item in loaded_loras: + parts = item.split(":") + lora_name = parts[0].strip() + strength_parts = parts[1].strip().split(",") + if len(strength_parts) > 1: + formatted_loras.append(f"") + else: + formatted_loras.append(f"") + + formatted_loras_text = " ".join(formatted_loras) + return (model, clip, trigger_words_text, formatted_loras_text) diff --git a/nodes/lora_pool.py b/nodes/lora_pool.py new file mode 100644 index 0000000..f24fd75 --- /dev/null +++ b/nodes/lora_pool.py @@ -0,0 +1,55 @@ +"""Remote LoRA Pool — pure pass-through, just different NAME.""" +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + + +class LoraPoolRemoteLM: + """LoRA Pool that passes through filter config (remote variant for NAME only).""" + + NAME = "Lora Pool (Remote, LoraManager)" + CATEGORY = "Lora Manager/randomizer" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "pool_config": ("LORA_POOL_CONFIG", {}), + }, + "hidden": { + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("POOL_CONFIG",) + RETURN_NAMES = ("POOL_CONFIG",) + FUNCTION = "process" + OUTPUT_NODE = False + + def process(self, pool_config, unique_id=None): + if not isinstance(pool_config, dict): + logger.warning("Invalid pool_config type, using empty config") + pool_config = self._default_config() + + if "version" not in pool_config: + pool_config["version"] = 1 + + filters = pool_config.get("filters", self._default_config()["filters"]) + logger.debug("[LoraPoolRemoteLM] Processing filters: %s", filters) + return (filters,) + + @staticmethod + def _default_config(): + return { + "version": 1, + "filters": { + "baseModels": [], + "tags": {"include": [], "exclude": []}, + "folders": {"include": [], "exclude": []}, + "favoritesOnly": False, + "license": {"noCreditRequired": False, "allowSelling": False}, + }, + "preview": {"matchCount": 0, "lastUpdated": 0}, + } diff --git a/nodes/lora_randomizer.py b/nodes/lora_randomizer.py new file mode 100644 index 0000000..6c055dc --- /dev/null +++ b/nodes/lora_randomizer.py @@ -0,0 +1,113 @@ +"""Remote LoRA Randomizer — uses remote API instead of local ServiceRegistry.""" +from __future__ import annotations + +import logging +import os + +from .remote_utils import get_lora_info_remote +from .utils import extract_lora_name +from ..remote_client import RemoteLoraClient + +logger = logging.getLogger(__name__) + + +class LoraRandomizerRemoteLM: + """Node that randomly selects LoRAs from a pool (remote metadata).""" + + NAME = "Lora Randomizer (Remote, LoraManager)" + CATEGORY = "Lora Manager/randomizer" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "randomizer_config": ("RANDOMIZER_CONFIG", {}), + "loras": ("LORAS", {}), + }, + "optional": { + "pool_config": ("POOL_CONFIG", {}), + }, + } + + RETURN_TYPES = ("LORA_STACK",) + RETURN_NAMES = ("LORA_STACK",) + FUNCTION = "randomize" + OUTPUT_NODE = False + + def _preprocess_loras_input(self, loras): + if isinstance(loras, dict) and "__value__" in loras: + return loras["__value__"] + return loras + + async def randomize(self, randomizer_config, loras, pool_config=None): + loras = self._preprocess_loras_input(loras) + + roll_mode = randomizer_config.get("roll_mode", "always") + execution_seed = randomizer_config.get("execution_seed", None) + next_seed = randomizer_config.get("next_seed", None) + + if roll_mode == "fixed": + ui_loras = loras + execution_loras = loras + else: + client = RemoteLoraClient.get_instance() + + # Build common kwargs for remote API + api_kwargs = self._build_api_kwargs(randomizer_config, loras, pool_config) + + if execution_seed is not None: + exec_kwargs = {**api_kwargs, "seed": execution_seed} + execution_loras = await client.get_random_loras(**exec_kwargs) + if not execution_loras: + execution_loras = loras + else: + execution_loras = loras + + ui_kwargs = {**api_kwargs, "seed": next_seed} + ui_loras = await client.get_random_loras(**ui_kwargs) + if not ui_loras: + ui_loras = loras + + execution_stack = self._build_execution_stack_from_input(execution_loras) + + return { + "result": (execution_stack,), + "ui": {"loras": ui_loras, "last_used": execution_loras}, + } + + def _build_api_kwargs(self, randomizer_config, input_loras, pool_config): + locked_loras = [l for l in input_loras if l.get("locked", False)] + return { + "count": int(randomizer_config.get("count_fixed", 5)), + "count_mode": randomizer_config.get("count_mode", "range"), + "count_min": int(randomizer_config.get("count_min", 3)), + "count_max": int(randomizer_config.get("count_max", 7)), + "model_strength_min": float(randomizer_config.get("model_strength_min", 0.0)), + "model_strength_max": float(randomizer_config.get("model_strength_max", 1.0)), + "use_same_clip_strength": randomizer_config.get("use_same_clip_strength", True), + "clip_strength_min": float(randomizer_config.get("clip_strength_min", 0.0)), + "clip_strength_max": float(randomizer_config.get("clip_strength_max", 1.0)), + "use_recommended_strength": randomizer_config.get("use_recommended_strength", False), + "recommended_strength_scale_min": float(randomizer_config.get("recommended_strength_scale_min", 0.5)), + "recommended_strength_scale_max": float(randomizer_config.get("recommended_strength_scale_max", 1.0)), + "locked_loras": locked_loras, + "pool_config": pool_config, + } + + def _build_execution_stack_from_input(self, loras): + lora_stack = [] + for lora in loras: + if not lora.get("active", False): + continue + + lora_path, _ = get_lora_info_remote(lora["name"]) + if not lora_path: + logger.warning("[LoraRandomizerRemoteLM] Could not find path for LoRA: %s", lora["name"]) + continue + + lora_path = lora_path.replace("/", os.sep) + model_strength = float(lora.get("strength", 1.0)) + clip_strength = float(lora.get("clipStrength", model_strength)) + lora_stack.append((lora_path, model_strength, clip_strength)) + + return lora_stack diff --git a/nodes/lora_stacker.py b/nodes/lora_stacker.py new file mode 100644 index 0000000..bf5b0ad --- /dev/null +++ b/nodes/lora_stacker.py @@ -0,0 +1,71 @@ +"""Remote LoRA Stacker — fetch metadata from the remote LoRA Manager.""" +from __future__ import annotations + +import logging +import os + +from .remote_utils import get_lora_info_remote +from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list + +logger = logging.getLogger(__name__) + + +class LoraStackerRemoteLM: + NAME = "Lora Stacker (Remote, LoraManager)" + CATEGORY = "Lora Manager/stackers" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "text": ("AUTOCOMPLETE_TEXT_LORAS", { + "placeholder": "Search LoRAs to add...", + "tooltip": "Format: separated by spaces or punctuation", + }), + }, + "optional": FlexibleOptionalInputType(any_type), + } + + RETURN_TYPES = ("LORA_STACK", "STRING", "STRING") + RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras") + FUNCTION = "stack_loras" + + def stack_loras(self, text, **kwargs): + stack = [] + active_loras = [] + all_trigger_words = [] + + lora_stack = kwargs.get("lora_stack", None) + if lora_stack: + stack.extend(lora_stack) + for lora_path, _, _ in lora_stack: + lora_name = extract_lora_name(lora_path) + _, trigger_words = get_lora_info_remote(lora_name) + all_trigger_words.extend(trigger_words) + + loras_list = get_loras_list(kwargs) + for lora in loras_list: + if not lora.get("active", False): + continue + + lora_name = lora["name"] + model_strength = float(lora["strength"]) + clip_strength = float(lora.get("clipStrength", model_strength)) + + lora_path, trigger_words = get_lora_info_remote(lora_name) + + stack.append((lora_path.replace("/", os.sep), model_strength, clip_strength)) + active_loras.append((lora_name, model_strength, clip_strength)) + all_trigger_words.extend(trigger_words) + + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + formatted_loras = [] + for name, model_strength, clip_strength in active_loras: + if abs(model_strength - clip_strength) > 0.001: + formatted_loras.append(f"") + else: + formatted_loras.append(f"") + + active_loras_text = " ".join(formatted_loras) + return (stack, trigger_words_text, active_loras_text) diff --git a/nodes/remote_utils.py b/nodes/remote_utils.py new file mode 100644 index 0000000..12a4830 --- /dev/null +++ b/nodes/remote_utils.py @@ -0,0 +1,44 @@ +"""Remote replacement for ``py/utils/utils.py:get_lora_info()``. + +Same signature: ``get_lora_info_remote(lora_name) -> (relative_path, trigger_words)`` +but fetches data from the remote LoRA Manager HTTP API instead of the local +ServiceRegistry / SQLite cache. +""" +from __future__ import annotations + +import asyncio +import concurrent.futures +import logging + +from ..remote_client import RemoteLoraClient + +logger = logging.getLogger(__name__) + + +def get_lora_info_remote(lora_name: str) -> tuple[str, list[str]]: + """Synchronous wrapper that calls the remote API for LoRA metadata. + + Uses the same sync-from-async bridge pattern as the original + ``get_lora_info()`` to be a drop-in replacement in node ``FUNCTION`` methods. + """ + async def _fetch(): + client = RemoteLoraClient.get_instance() + return await client.get_lora_info(lora_name) + + try: + asyncio.get_running_loop() + # Already inside an event loop — run in a separate thread. + def _run_in_thread(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(_fetch()) + finally: + loop.close() + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(_run_in_thread) + return future.result() + except RuntimeError: + # No running loop — safe to use asyncio.run() + return asyncio.run(_fetch()) diff --git a/nodes/save_image.py b/nodes/save_image.py new file mode 100644 index 0000000..3ead413 --- /dev/null +++ b/nodes/save_image.py @@ -0,0 +1,381 @@ +"""Remote Save Image — uses remote API for hash lookups.""" +from __future__ import annotations + +import asyncio +import concurrent.futures +import json +import logging +import os +import re + +import folder_paths # type: ignore +import numpy as np +from PIL import Image, PngImagePlugin + +from ..remote_client import RemoteLoraClient + +logger = logging.getLogger(__name__) + +try: + import piexif +except ImportError: + piexif = None + + +class SaveImageRemoteLM: + NAME = "Save Image (Remote, LoraManager)" + CATEGORY = "Lora Manager/utils" + DESCRIPTION = "Save images with embedded generation metadata (remote hash lookup)" + + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + self.compress_level = 4 + self.counter = 0 + + pattern_format = re.compile(r"(%[^%]+%)") + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE",), + "filename_prefix": ("STRING", { + "default": "ComfyUI", + "tooltip": "Base filename. Supports %seed%, %width%, %height%, %model%, etc.", + }), + "file_format": (["png", "jpeg", "webp"], { + "tooltip": "Image format to save as.", + }), + }, + "optional": { + "lossless_webp": ("BOOLEAN", {"default": False}), + "quality": ("INT", {"default": 100, "min": 1, "max": 100}), + "embed_workflow": ("BOOLEAN", {"default": False}), + "add_counter_to_filename": ("BOOLEAN", {"default": True}), + }, + "hidden": { + "id": "UNIQUE_ID", + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO", + }, + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) + FUNCTION = "process_image" + OUTPUT_NODE = True + + # ------------------------------------------------------------------ + # Remote hash lookups + # ------------------------------------------------------------------ + + def _run_async(self, coro): + """Run an async coroutine from sync context.""" + try: + asyncio.get_running_loop() + + def _in_thread(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(_in_thread).result() + except RuntimeError: + return asyncio.run(coro) + + def get_lora_hash(self, lora_name): + client = RemoteLoraClient.get_instance() + return self._run_async(client.get_lora_hash(lora_name)) + + def get_checkpoint_hash(self, checkpoint_path): + if not checkpoint_path: + return None + checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0] + client = RemoteLoraClient.get_instance() + return self._run_async(client.get_checkpoint_hash(checkpoint_name)) + + # ------------------------------------------------------------------ + # Metadata formatting (identical to original) + # ------------------------------------------------------------------ + + def format_metadata(self, metadata_dict): + if not metadata_dict: + return "" + + def add_param_if_not_none(param_list, label, value): + if value is not None: + param_list.append(f"{label}: {value}") + + prompt = metadata_dict.get("prompt", "") + negative_prompt = metadata_dict.get("negative_prompt", "") + loras_text = metadata_dict.get("loras", "") + lora_hashes = {} + + if loras_text: + prompt_with_loras = f"{prompt}\n{loras_text}" + lora_matches = re.findall(r"]+)>", loras_text) + for lora_name, _ in lora_matches: + hash_value = self.get_lora_hash(lora_name) + if hash_value: + lora_hashes[lora_name] = hash_value + else: + prompt_with_loras = prompt + + metadata_parts = [prompt_with_loras] + + if negative_prompt: + metadata_parts.append(f"Negative prompt: {negative_prompt}") + + params = [] + + if "steps" in metadata_dict: + add_param_if_not_none(params, "Steps", metadata_dict.get("steps")) + + sampler_name = None + scheduler_name = None + + if "sampler" in metadata_dict: + sampler_mapping = { + "euler": "Euler", "euler_ancestral": "Euler a", + "dpm_2": "DPM2", "dpm_2_ancestral": "DPM2 a", + "heun": "Heun", "dpm_fast": "DPM fast", + "dpm_adaptive": "DPM adaptive", "lms": "LMS", + "dpmpp_2s_ancestral": "DPM++ 2S a", "dpmpp_sde": "DPM++ SDE", + "dpmpp_sde_gpu": "DPM++ SDE", "dpmpp_2m": "DPM++ 2M", + "dpmpp_2m_sde": "DPM++ 2M SDE", "dpmpp_2m_sde_gpu": "DPM++ 2M SDE", + "ddim": "DDIM", + } + sampler_name = sampler_mapping.get(metadata_dict["sampler"], metadata_dict["sampler"]) + + if "scheduler" in metadata_dict: + scheduler_mapping = { + "normal": "Simple", "karras": "Karras", + "exponential": "Exponential", "sgm_uniform": "SGM Uniform", + "sgm_quadratic": "SGM Quadratic", + } + scheduler_name = scheduler_mapping.get(metadata_dict["scheduler"], metadata_dict["scheduler"]) + + if sampler_name: + if scheduler_name: + params.append(f"Sampler: {sampler_name} {scheduler_name}") + else: + params.append(f"Sampler: {sampler_name}") + + if "guidance" in metadata_dict: + add_param_if_not_none(params, "CFG scale", metadata_dict.get("guidance")) + elif "cfg_scale" in metadata_dict: + add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg_scale")) + elif "cfg" in metadata_dict: + add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg")) + + if "seed" in metadata_dict: + add_param_if_not_none(params, "Seed", metadata_dict.get("seed")) + + if "size" in metadata_dict: + add_param_if_not_none(params, "Size", metadata_dict.get("size")) + + if "checkpoint" in metadata_dict: + checkpoint = metadata_dict.get("checkpoint") + if checkpoint is not None: + model_hash = self.get_checkpoint_hash(checkpoint) + checkpoint_name = os.path.splitext(os.path.basename(checkpoint))[0] + if model_hash: + params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}") + else: + params.append(f"Model: {checkpoint_name}") + + if lora_hashes: + lora_hash_parts = [f"{n}: {h[:10]}" for n, h in lora_hashes.items()] + params.append(f'Lora hashes: "{", ".join(lora_hash_parts)}"') + + metadata_parts.append(", ".join(params)) + return "\n".join(metadata_parts) + + def format_filename(self, filename, metadata_dict): + if not metadata_dict: + return filename + + result = re.findall(self.pattern_format, filename) + for segment in result: + parts = segment.replace("%", "").split(":") + key = parts[0] + + if key == "seed" and "seed" in metadata_dict: + filename = filename.replace(segment, str(metadata_dict.get("seed", ""))) + elif key == "width" and "size" in metadata_dict: + size = metadata_dict.get("size", "x") + w = size.split("x")[0] if isinstance(size, str) else size[0] + filename = filename.replace(segment, str(w)) + elif key == "height" and "size" in metadata_dict: + size = metadata_dict.get("size", "x") + h = size.split("x")[1] if isinstance(size, str) else size[1] + filename = filename.replace(segment, str(h)) + elif key == "pprompt" and "prompt" in metadata_dict: + p = metadata_dict.get("prompt", "").replace("\n", " ") + if len(parts) >= 2: + p = p[: int(parts[1])] + filename = filename.replace(segment, p.strip()) + elif key == "nprompt" and "negative_prompt" in metadata_dict: + p = metadata_dict.get("negative_prompt", "").replace("\n", " ") + if len(parts) >= 2: + p = p[: int(parts[1])] + filename = filename.replace(segment, p.strip()) + elif key == "model": + model_value = metadata_dict.get("checkpoint") + if isinstance(model_value, (bytes, os.PathLike)): + model_value = str(model_value) + if not isinstance(model_value, str) or not model_value: + model = "model_unavailable" + else: + model = os.path.splitext(os.path.basename(model_value))[0] + if len(parts) >= 2: + model = model[: int(parts[1])] + filename = filename.replace(segment, model) + elif key == "date": + from datetime import datetime + now = datetime.now() + date_table = { + "yyyy": f"{now.year:04d}", "yy": f"{now.year % 100:02d}", + "MM": f"{now.month:02d}", "dd": f"{now.day:02d}", + "hh": f"{now.hour:02d}", "mm": f"{now.minute:02d}", + "ss": f"{now.second:02d}", + } + if len(parts) >= 2: + date_format = parts[1] + else: + date_format = "yyyyMMddhhmmss" + for k, v in date_table.items(): + date_format = date_format.replace(k, v) + filename = filename.replace(segment, date_format) + + return filename + + # ------------------------------------------------------------------ + # Image saving + # ------------------------------------------------------------------ + + def save_images(self, images, filename_prefix, file_format, id, prompt=None, + extra_pnginfo=None, lossless_webp=True, quality=100, + embed_workflow=False, add_counter_to_filename=True): + results = [] + + # Try to get metadata from the original LoRA Manager's collector. + # The package directory name varies across installs, so we search + # sys.modules for any loaded module whose path ends with the + # expected submodule. + metadata_dict = {} + try: + get_metadata = None + MetadataProcessor = None + + import sys + for mod_name, mod in sys.modules.items(): + if mod is None: + continue + if mod_name.endswith(".py.metadata_collector") and hasattr(mod, "get_metadata"): + get_metadata = mod.get_metadata + if mod_name.endswith(".py.metadata_collector.metadata_processor") and hasattr(mod, "MetadataProcessor"): + MetadataProcessor = mod.MetadataProcessor + if get_metadata and MetadataProcessor: + break + + if get_metadata and MetadataProcessor: + raw_metadata = get_metadata() + metadata_dict = MetadataProcessor.to_dict(raw_metadata, id) + else: + logger.debug("[LM-Remote] metadata_collector not found in loaded modules") + except Exception: + logger.debug("[LM-Remote] metadata_collector not available, saving without generation metadata") + + metadata = self.format_metadata(metadata_dict) + filename_prefix = self.format_filename(filename_prefix, metadata_dict) + + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0] + ) + + if not os.path.exists(full_output_folder): + os.makedirs(full_output_folder, exist_ok=True) + + for i, image in enumerate(images): + img = 255.0 * image.cpu().numpy() + img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8)) + + base_filename = filename + if add_counter_to_filename: + current_counter = counter + i + base_filename += f"_{current_counter:05}_" + + if file_format == "png": + file = base_filename + ".png" + save_kwargs = {"compress_level": self.compress_level} + pnginfo = PngImagePlugin.PngInfo() + elif file_format == "jpeg": + file = base_filename + ".jpg" + save_kwargs = {"quality": quality, "optimize": True} + elif file_format == "webp": + file = base_filename + ".webp" + save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0} + + file_path = os.path.join(full_output_folder, file) + + try: + if file_format == "png": + if metadata: + pnginfo.add_text("parameters", metadata) + if embed_workflow and extra_pnginfo is not None: + pnginfo.add_text("workflow", json.dumps(extra_pnginfo["workflow"])) + save_kwargs["pnginfo"] = pnginfo + img.save(file_path, format="PNG", **save_kwargs) + elif file_format == "jpeg" and piexif: + if metadata: + try: + exif_dict = {"Exif": {piexif.ExifIFD.UserComment: b"UNICODE\0" + metadata.encode("utf-16be")}} + save_kwargs["exif"] = piexif.dump(exif_dict) + except Exception as e: + logger.error("Error adding EXIF data: %s", e) + img.save(file_path, format="JPEG", **save_kwargs) + elif file_format == "webp" and piexif: + try: + exif_dict = {} + if metadata: + exif_dict["Exif"] = {piexif.ExifIFD.UserComment: b"UNICODE\0" + metadata.encode("utf-16be")} + if embed_workflow and extra_pnginfo is not None: + exif_dict["0th"] = {piexif.ImageIFD.ImageDescription: "Workflow:" + json.dumps(extra_pnginfo["workflow"])} + save_kwargs["exif"] = piexif.dump(exif_dict) + except Exception as e: + logger.error("Error adding EXIF data: %s", e) + img.save(file_path, format="WEBP", **save_kwargs) + else: + img.save(file_path) + + results.append({"filename": file, "subfolder": subfolder, "type": self.type}) + except Exception as e: + logger.error("Error saving image: %s", e) + + return results + + def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png", + prompt=None, extra_pnginfo=None, lossless_webp=True, quality=100, + embed_workflow=False, add_counter_to_filename=True): + os.makedirs(self.output_dir, exist_ok=True) + + if isinstance(images, (list, np.ndarray)): + pass + else: + if len(images.shape) == 3: + images = [images] + else: + images = [img for img in images] + + self.save_images( + images, filename_prefix, file_format, id, prompt, extra_pnginfo, + lossless_webp, quality, embed_workflow, add_counter_to_filename, + ) + return (images,) diff --git a/nodes/utils.py b/nodes/utils.py new file mode 100644 index 0000000..828cf3d --- /dev/null +++ b/nodes/utils.py @@ -0,0 +1,141 @@ +"""Minimal utility classes/functions copied from the original LoRA Manager. + +Only the pieces needed by the remote node classes are included here so that +ComfyUI-LM-Remote can function independently of the original package's Python +internals (while still requiring its JS widget files). +""" +from __future__ import annotations + +import copy +import logging +import os +import sys + +import folder_paths # type: ignore + +logger = logging.getLogger(__name__) + + +class AnyType(str): + """A special class that is always equal in not-equal comparisons. + + Credit to pythongosssss. + """ + + def __ne__(self, __value: object) -> bool: + return False + + +class FlexibleOptionalInputType(dict): + """Allow flexible/dynamic input types on ComfyUI nodes. + + Credit to Regis Gaughan, III (rgthree). + """ + + def __init__(self, type): + self.type = type + + def __getitem__(self, key): + return (self.type,) + + def __contains__(self, key): + return True + + +any_type = AnyType("*") + + +def extract_lora_name(lora_path: str) -> str: + """``'IL\\\\aorunIllstrious.safetensors'`` -> ``'aorunIllstrious'``""" + basename = os.path.basename(lora_path) + return os.path.splitext(basename)[0] + + +def get_loras_list(kwargs: dict) -> list: + """Extract loras list from either old or new kwargs format.""" + if "loras" not in kwargs: + return [] + loras_data = kwargs["loras"] + if isinstance(loras_data, dict) and "__value__" in loras_data: + return loras_data["__value__"] + elif isinstance(loras_data, list): + return loras_data + else: + logger.warning("Unexpected loras format: %s", type(loras_data)) + return [] + + +# --------------------------------------------------------------------------- +# Nunchaku LoRA helpers (copied verbatim from original) +# --------------------------------------------------------------------------- + +def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""): + import safetensors.torch + + state_dict = {} + with safetensors.torch.safe_open(path, framework="pt", device=device) as f: + for k in f.keys(): + if filter_prefix and not k.startswith(filter_prefix): + continue + state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k) + return state_dict + + +def to_diffusers(input_lora): + import torch + from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft + from diffusers.loaders import FluxLoraLoaderMixin + + if isinstance(input_lora, str): + tensors = load_state_dict_in_safetensors(input_lora, device="cpu") + else: + tensors = {k: v for k, v in input_lora.items()} + + for k, v in tensors.items(): + if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]: + tensors[k] = v.to(torch.bfloat16) + + new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors) + new_tensors = convert_unet_state_dict_to_peft(new_tensors) + return new_tensors + + +def nunchaku_load_lora(model, lora_name, lora_strength): + lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name) + if not lora_path or not os.path.isfile(lora_path): + logger.warning("Skipping LoRA '%s' because it could not be found", lora_name) + return model + + model_wrapper = model.model.diffusion_model + + module_name = model_wrapper.__class__.__module__ + module = sys.modules.get(module_name) + copy_with_ctx = getattr(module, "copy_with_ctx", None) + + if copy_with_ctx is not None: + ret_model_wrapper, ret_model = copy_with_ctx(model_wrapper) + ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)] + else: + logger.warning( + "Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. " + "Falling back to legacy loading logic." + ) + transformer = model_wrapper.model + model_wrapper.model = None + ret_model = copy.deepcopy(model) + ret_model_wrapper = ret_model.model.diffusion_model + model_wrapper.model = transformer + ret_model_wrapper.model = transformer + ret_model_wrapper.loras.append((lora_path, lora_strength)) + + sd = to_diffusers(lora_path) + + if "transformer.x_embedder.lora_A.weight" in sd: + new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1] + assert new_in_channels % 4 == 0 + new_in_channels = new_in_channels // 4 + old_in_channels = ret_model.model.model_config.unet_config["in_channels"] + if old_in_channels < new_in_channels: + ret_model.model.model_config.unet_config["in_channels"] = new_in_channels + + return ret_model diff --git a/nodes/wanvideo.py b/nodes/wanvideo.py new file mode 100644 index 0000000..04510c4 --- /dev/null +++ b/nodes/wanvideo.py @@ -0,0 +1,202 @@ +"""Remote WanVideo LoRA nodes — fetch metadata from the remote LoRA Manager.""" +from __future__ import annotations + +import logging + +import folder_paths # type: ignore + +from .remote_utils import get_lora_info_remote +from .utils import FlexibleOptionalInputType, any_type, get_loras_list + +logger = logging.getLogger(__name__) + + +class WanVideoLoraSelectRemoteLM: + NAME = "WanVideo Lora Select (Remote, LoraManager)" + CATEGORY = "Lora Manager/stackers" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "low_mem_load": ("BOOLEAN", { + "default": False, + "tooltip": "Load LORA models with less VRAM usage, slower loading.", + }), + "merge_loras": ("BOOLEAN", { + "default": True, + "tooltip": "Merge LoRAs into the model.", + }), + "text": ("AUTOCOMPLETE_TEXT_LORAS", { + "placeholder": "Search LoRAs to add...", + "tooltip": "Format: ", + }), + }, + "optional": FlexibleOptionalInputType(any_type), + } + + RETURN_TYPES = ("WANVIDLORA", "STRING", "STRING") + RETURN_NAMES = ("lora", "trigger_words", "active_loras") + FUNCTION = "process_loras" + + def process_loras(self, text, low_mem_load=False, merge_loras=True, **kwargs): + loras_list = [] + all_trigger_words = [] + active_loras = [] + + prev_lora = kwargs.get("prev_lora", None) + if prev_lora is not None: + loras_list.extend(prev_lora) + + if not merge_loras: + low_mem_load = False + + blocks = kwargs.get("blocks", {}) + selected_blocks = blocks.get("selected_blocks", {}) + layer_filter = blocks.get("layer_filter", "") + + loras_from_widget = get_loras_list(kwargs) + for lora in loras_from_widget: + if not lora.get("active", False): + continue + + lora_name = lora["name"] + model_strength = float(lora["strength"]) + clip_strength = float(lora.get("clipStrength", model_strength)) + + lora_path, trigger_words = get_lora_info_remote(lora_name) + + lora_item = { + "path": folder_paths.get_full_path("loras", lora_path), + "strength": model_strength, + "name": lora_path.split(".")[0], + "blocks": selected_blocks, + "layer_filter": layer_filter, + "low_mem_load": low_mem_load, + "merge_loras": merge_loras, + } + + loras_list.append(lora_item) + active_loras.append((lora_name, model_strength, clip_strength)) + all_trigger_words.extend(trigger_words) + + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + formatted_loras = [] + for name, ms, cs in active_loras: + if abs(ms - cs) > 0.001: + formatted_loras.append(f"") + else: + formatted_loras.append(f"") + + active_loras_text = " ".join(formatted_loras) + return (loras_list, trigger_words_text, active_loras_text) + + +class WanVideoLoraTextSelectRemoteLM: + NAME = "WanVideo Lora Select From Text (Remote, LoraManager)" + CATEGORY = "Lora Manager/stackers" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "low_mem_load": ("BOOLEAN", { + "default": False, + "tooltip": "Load LORA models with less VRAM usage, slower loading.", + }), + "merge_lora": ("BOOLEAN", { + "default": True, + "tooltip": "Merge LoRAs into the model.", + }), + "lora_syntax": ("STRING", { + "multiline": True, + "forceInput": True, + "tooltip": "Connect a TEXT output for LoRA syntax: ", + }), + }, + "optional": { + "prev_lora": ("WANVIDLORA",), + "blocks": ("BLOCKS",), + }, + } + + RETURN_TYPES = ("WANVIDLORA", "STRING", "STRING") + RETURN_NAMES = ("lora", "trigger_words", "active_loras") + FUNCTION = "process_loras_from_syntax" + + def process_loras_from_syntax(self, lora_syntax, low_mem_load=False, merge_lora=True, **kwargs): + blocks = kwargs.get("blocks", {}) + selected_blocks = blocks.get("selected_blocks", {}) + layer_filter = blocks.get("layer_filter", "") + + loras_list = [] + all_trigger_words = [] + active_loras = [] + + prev_lora = kwargs.get("prev_lora", None) + if prev_lora is not None: + loras_list.extend(prev_lora) + + if not merge_lora: + low_mem_load = False + + parts = lora_syntax.split("") + if end_index == -1: + continue + + content = part[:end_index] + lora_parts = content.split(":") + + lora_name_raw = "" + model_strength = 1.0 + clip_strength = 1.0 + + if len(lora_parts) == 2: + lora_name_raw = lora_parts[0].strip() + try: + model_strength = float(lora_parts[1]) + clip_strength = model_strength + except (ValueError, IndexError): + logger.warning("Invalid strength for LoRA '%s'. Skipping.", lora_name_raw) + continue + elif len(lora_parts) >= 3: + lora_name_raw = lora_parts[0].strip() + try: + model_strength = float(lora_parts[1]) + clip_strength = float(lora_parts[2]) + except (ValueError, IndexError): + logger.warning("Invalid strengths for LoRA '%s'. Skipping.", lora_name_raw) + continue + else: + continue + + lora_path, trigger_words = get_lora_info_remote(lora_name_raw) + + lora_item = { + "path": folder_paths.get_full_path("loras", lora_path), + "strength": model_strength, + "name": lora_path.split(".")[0], + "blocks": selected_blocks, + "layer_filter": layer_filter, + "low_mem_load": low_mem_load, + "merge_loras": merge_lora, + } + + loras_list.append(lora_item) + active_loras.append((lora_name_raw, model_strength, clip_strength)) + all_trigger_words.extend(trigger_words) + + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + formatted_loras = [] + for name, ms, cs in active_loras: + if abs(ms - cs) > 0.001: + formatted_loras.append(f"") + else: + formatted_loras.append(f"") + + active_loras_text = " ".join(formatted_loras) + return (loras_list, trigger_words_text, active_loras_text) diff --git a/proxy.py b/proxy.py new file mode 100644 index 0000000..63de022 --- /dev/null +++ b/proxy.py @@ -0,0 +1,243 @@ +""" +Reverse-proxy middleware that forwards LoRA Manager requests to the remote instance. + +Registered as an aiohttp middleware on PromptServer.instance.app. It intercepts +requests matching known LoRA Manager URL prefixes and proxies them to the remote +Docker instance. Non-matching requests fall through to the regular ComfyUI router. + +Routes that use ``PromptServer.instance.send_sync()`` are explicitly excluded +from proxying so the local original LoRA Manager handler can broadcast events +to the local ComfyUI frontend. +""" +from __future__ import annotations + +import asyncio +import logging + +import aiohttp +from aiohttp import web, WSMsgType + +from .config import remote_config + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# URL prefixes that should be forwarded to the remote LoRA Manager +# --------------------------------------------------------------------------- +_PROXY_PREFIXES = ( + "/api/lm/", + "/loras_static/", + "/locales/", + "/example_images_static/", +) + +# Page routes served by the standalone LoRA Manager web UI +_PROXY_PAGE_ROUTES = { + "/loras", + "/checkpoints", + "/embeddings", + "/loras/recipes", + "/statistics", +} + +# WebSocket endpoints to proxy +_WS_ROUTES = { + "/ws/fetch-progress", + "/ws/download-progress", + "/ws/init-progress", +} + +# Routes that call send_sync on the remote side — these are NOT proxied. +# Instead they fall through to the local original LoRA Manager handler, +# which broadcasts events to the local ComfyUI frontend. The remote +# handler would broadcast to its own (empty) frontend, which is useless. +# +# These routes: +# /api/lm/loras/get_trigger_words -> trigger_word_update event +# /api/lm/update-lora-code -> lora_code_update event +# /api/lm/update-node-widget -> lm_widget_update event +# /api/lm/register-nodes -> lora_registry_refresh event +_SEND_SYNC_SKIP_ROUTES = { + "/api/lm/loras/get_trigger_words", + "/api/lm/update-lora-code", + "/api/lm/update-node-widget", + "/api/lm/register-nodes", +} + +# Shared HTTP session for proxied requests (connection pooling) +_proxy_session: aiohttp.ClientSession | None = None + + +async def _get_proxy_session() -> aiohttp.ClientSession: + """Return a shared aiohttp session for HTTP proxy requests.""" + global _proxy_session + if _proxy_session is None or _proxy_session.closed: + timeout = aiohttp.ClientTimeout(total=remote_config.timeout) + _proxy_session = aiohttp.ClientSession(timeout=timeout) + return _proxy_session + + +def _should_proxy(path: str) -> bool: + """Return True if *path* should be proxied to the remote instance.""" + if any(path.startswith(p) for p in _PROXY_PREFIXES): + return True + if path in _PROXY_PAGE_ROUTES or path.rstrip("/") in _PROXY_PAGE_ROUTES: + return True + return False + + +def _is_ws_route(path: str) -> bool: + return path in _WS_ROUTES + + +async def _proxy_ws(request: web.Request) -> web.WebSocketResponse: + """Proxy a WebSocket connection to the remote LoRA Manager.""" + remote_url = remote_config.remote_url.replace("http://", "ws://").replace("https://", "wss://") + remote_ws_url = f"{remote_url}{request.path}" + if request.query_string: + remote_ws_url += f"?{request.query_string}" + + local_ws = web.WebSocketResponse() + await local_ws.prepare(request) + + timeout = aiohttp.ClientTimeout(total=None) + session = aiohttp.ClientSession(timeout=timeout) + try: + async with session.ws_connect(remote_ws_url) as remote_ws: + + async def forward_local_to_remote(): + async for msg in local_ws: + if msg.type == WSMsgType.TEXT: + await remote_ws.send_str(msg.data) + elif msg.type == WSMsgType.BINARY: + await remote_ws.send_bytes(msg.data) + elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + return + + async def forward_remote_to_local(): + async for msg in remote_ws: + if msg.type == WSMsgType.TEXT: + await local_ws.send_str(msg.data) + elif msg.type == WSMsgType.BINARY: + await local_ws.send_bytes(msg.data) + elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + return + + # Run both directions concurrently. When either side closes, + # cancel the other to prevent hanging. + task_l2r = asyncio.create_task(forward_local_to_remote()) + task_r2l = asyncio.create_task(forward_remote_to_local()) + try: + done, pending = await asyncio.wait( + {task_l2r, task_r2l}, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + finally: + # Ensure both sides are closed + if not remote_ws.closed: + await remote_ws.close() + if not local_ws.closed: + await local_ws.close() + + except Exception as exc: + logger.warning("[LM-Remote] WebSocket proxy error for %s: %s", request.path, exc) + finally: + await session.close() + + return local_ws + + +async def _proxy_http(request: web.Request) -> web.Response: + """Forward an HTTP request to the remote LoRA Manager and return its response.""" + remote_url = f"{remote_config.remote_url}{request.path}" + if request.query_string: + remote_url += f"?{request.query_string}" + + # Read the request body (if any) + body = await request.read() if request.can_read_body else None + + # Filter hop-by-hop headers + headers = {} + skip = {"host", "transfer-encoding", "connection", "keep-alive", "upgrade"} + for k, v in request.headers.items(): + if k.lower() not in skip: + headers[k] = v + + session = await _get_proxy_session() + try: + async with session.request( + method=request.method, + url=remote_url, + headers=headers, + data=body, + ) as resp: + resp_body = await resp.read() + resp_headers = {} + for k, v in resp.headers.items(): + if k.lower() not in ("transfer-encoding", "content-encoding", "content-length"): + resp_headers[k] = v + return web.Response( + status=resp.status, + body=resp_body, + headers=resp_headers, + ) + except Exception as exc: + logger.error("[LM-Remote] Proxy error for %s %s: %s", request.method, request.path, exc) + return web.json_response( + {"error": f"Remote LoRA Manager unavailable: {exc}"}, + status=502, + ) + + +# --------------------------------------------------------------------------- +# Middleware factory +# --------------------------------------------------------------------------- + +@web.middleware +async def lm_remote_proxy_middleware(request: web.Request, handler): + """aiohttp middleware that intercepts LoRA Manager requests.""" + if not remote_config.is_configured: + return await handler(request) + + path = request.path + + # Routes that use send_sync must NOT be proxied — let the local + # original LoRA Manager handle them so events reach the local browser. + if path in _SEND_SYNC_SKIP_ROUTES: + return await handler(request) + + # WebSocket routes + if _is_ws_route(path): + return await _proxy_ws(request) + + # Regular proxy routes + if _should_proxy(path): + return await _proxy_http(request) + + # Not a LoRA Manager route — fall through + return await handler(request) + + +async def _cleanup_proxy_session(app) -> None: + """Shutdown hook to close the shared proxy session.""" + global _proxy_session + if _proxy_session and not _proxy_session.closed: + await _proxy_session.close() + _proxy_session = None + + +def register_proxy(app) -> None: + """Insert the proxy middleware into the aiohttp app.""" + if not remote_config.is_configured: + logger.warning("[LM-Remote] No remote_url configured — proxy disabled") + return + + # Insert at position 0 so we run before the original LoRA Manager routes + app.middlewares.insert(0, lm_remote_proxy_middleware) + app.on_shutdown.append(_cleanup_proxy_session) + logger.info("[LM-Remote] Proxy routes registered -> %s", remote_config.remote_url) diff --git a/remote_client.py b/remote_client.py new file mode 100644 index 0000000..35f4d83 --- /dev/null +++ b/remote_client.py @@ -0,0 +1,214 @@ +"""HTTP client for the remote LoRA Manager instance.""" +from __future__ import annotations + +import asyncio +import logging +import time +from typing import Any + +import aiohttp + +from .config import remote_config + +logger = logging.getLogger(__name__) + +# Cache TTL in seconds — how long before we re-fetch the full LoRA list +_CACHE_TTL = 60 + + +class RemoteLoraClient: + """Singleton HTTP client that talks to the remote LoRA Manager. + + Uses the actual LoRA Manager REST API endpoints: + - ``GET /api/lm/loras/list?page_size=9999`` — paginated LoRA list + - ``GET /api/lm/loras/get-trigger-words?name=X`` — trigger words + - ``POST /api/lm/loras/random-sample`` — random LoRA selection + - ``POST /api/lm/loras/cycler-list`` — sorted LoRA list for cycler + + A short-lived in-memory cache avoids redundant calls to the list endpoint + during a single workflow execution (which may resolve many LoRAs at once). + """ + + _instance: RemoteLoraClient | None = None + _session: aiohttp.ClientSession | None = None + + def __init__(self): + self._lora_cache: list[dict] = [] + self._lora_cache_ts: float = 0 + self._checkpoint_cache: list[dict] = [] + self._checkpoint_cache_ts: float = 0 + + @classmethod + def get_instance(cls) -> RemoteLoraClient: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + timeout = aiohttp.ClientTimeout(total=remote_config.timeout) + self._session = aiohttp.ClientSession(timeout=timeout) + return self._session + + async def close(self): + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + # ------------------------------------------------------------------ + # Core HTTP helpers + # ------------------------------------------------------------------ + + async def _get_json(self, path: str, params: dict | None = None) -> Any: + url = f"{remote_config.remote_url}{path}" + session = await self._get_session() + async with session.get(url, params=params) as resp: + resp.raise_for_status() + return await resp.json() + + async def _post_json(self, path: str, json_body: dict | None = None) -> Any: + url = f"{remote_config.remote_url}{path}" + session = await self._get_session() + async with session.post(url, json=json_body) as resp: + resp.raise_for_status() + return await resp.json() + + # ------------------------------------------------------------------ + # Cached list helpers + # ------------------------------------------------------------------ + + async def _get_lora_list_cached(self) -> list[dict]: + """Return the full LoRA list, using a short-lived cache.""" + now = time.monotonic() + if self._lora_cache and (now - self._lora_cache_ts) < _CACHE_TTL: + return self._lora_cache + + try: + data = await self._get_json( + "/api/lm/loras/list", params={"page_size": "9999"} + ) + self._lora_cache = data.get("items", []) + self._lora_cache_ts = now + except Exception as exc: + logger.warning("[LM-Remote] Failed to fetch LoRA list: %s", exc) + # Return stale cache on error, or empty list + return self._lora_cache + + async def _get_checkpoint_list_cached(self) -> list[dict]: + """Return the full checkpoint list, using a short-lived cache.""" + now = time.monotonic() + if self._checkpoint_cache and (now - self._checkpoint_cache_ts) < _CACHE_TTL: + return self._checkpoint_cache + + try: + data = await self._get_json( + "/api/lm/checkpoints/list", params={"page_size": "9999"} + ) + self._checkpoint_cache = data.get("items", []) + self._checkpoint_cache_ts = now + except Exception as exc: + logger.warning("[LM-Remote] Failed to fetch checkpoint list: %s", exc) + return self._checkpoint_cache + + def _find_item_by_name(self, items: list[dict], name: str) -> dict | None: + """Find an item in a list by file_name.""" + for item in items: + if item.get("file_name") == name: + return item + return None + + # ------------------------------------------------------------------ + # LoRA metadata + # ------------------------------------------------------------------ + + async def get_lora_info(self, lora_name: str) -> tuple[str, list[str]]: + """Return (relative_path, trigger_words) for a LoRA by display name. + + Uses the cached ``/api/lm/loras/list`` data. Falls back to the + per-LoRA ``get-trigger-words`` endpoint if the list lookup fails. + """ + import posixpath + + try: + items = await self._get_lora_list_cached() + item = self._find_item_by_name(items, lora_name) + + if item: + file_path = item.get("file_path", "") + file_path = remote_config.map_path(file_path) + + # file_path is the absolute path (forward-slashed) from + # the remote. We need a relative path that the local + # folder_paths.get_full_path("loras", ...) can resolve. + # + # The ``folder`` field gives the subfolder within the + # model root (e.g. "anime" or "anime/characters"). + # The basename of file_path has the extension. + # + # Example: file_path="/mnt/loras/anime/test.safetensors" + # folder="anime" + # -> basename="test.safetensors" + # -> relative="anime/test.safetensors" + folder = item.get("folder", "") + basename = posixpath.basename(file_path) # "test.safetensors" + + if folder: + relative = f"{folder}/{basename}" + else: + relative = basename + + civitai = item.get("civitai") or {} + trigger_words = civitai.get("trainedWords", []) if civitai else [] + return relative, trigger_words + + # Fallback: try the specific trigger-words endpoint + tw_data = await self._get_json( + "/api/lm/loras/get-trigger-words", + params={"name": lora_name}, + ) + trigger_words = tw_data.get("trigger_words", []) + return lora_name, trigger_words + + except Exception as exc: + logger.warning("[LM-Remote] get_lora_info(%s) failed: %s", lora_name, exc) + return lora_name, [] + + async def get_lora_hash(self, lora_name: str) -> str | None: + """Return the SHA-256 hash for a LoRA by display name.""" + try: + items = await self._get_lora_list_cached() + item = self._find_item_by_name(items, lora_name) + if item: + return item.get("sha256") or item.get("hash") + except Exception as exc: + logger.warning("[LM-Remote] get_lora_hash(%s) failed: %s", lora_name, exc) + return None + + async def get_checkpoint_hash(self, checkpoint_name: str) -> str | None: + """Return the SHA-256 hash for a checkpoint by display name.""" + try: + items = await self._get_checkpoint_list_cached() + item = self._find_item_by_name(items, checkpoint_name) + if item: + return item.get("sha256") or item.get("hash") + except Exception as exc: + logger.warning("[LM-Remote] get_checkpoint_hash(%s) failed: %s", checkpoint_name, exc) + return None + + async def get_random_loras(self, **kwargs) -> list[dict]: + """Ask the remote to generate random LoRAs (for Randomizer node).""" + try: + result = await self._post_json("/api/lm/loras/random-sample", json_body=kwargs) + return result if isinstance(result, list) else result.get("loras", []) + except Exception as exc: + logger.warning("[LM-Remote] get_random_loras failed: %s", exc) + return [] + + async def get_cycler_list(self, **kwargs) -> list[dict]: + """Ask the remote for a sorted LoRA list (for Cycler node).""" + try: + result = await self._post_json("/api/lm/loras/cycler-list", json_body=kwargs) + return result if isinstance(result, list) else result.get("loras", []) + except Exception as exc: + logger.warning("[LM-Remote] get_cycler_list failed: %s", exc) + return [] diff --git a/web/comfyui/lora_loader_remote.js b/web/comfyui/lora_loader_remote.js new file mode 100644 index 0000000..0b40ebc --- /dev/null +++ b/web/comfyui/lora_loader_remote.js @@ -0,0 +1,166 @@ +/** + * JS shim for Remote Lora Loader / Remote Lora Text Loader nodes. + * + * Re-uses all widget infrastructure from the original ComfyUI-Lora-Manager; + * the only difference is matching on the remote node NAMEs. + */ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; +import { + collectActiveLorasFromChain, + updateConnectedTriggerWords, + chainCallback, + mergeLoras, + getAllGraphNodes, + getNodeFromGraph, +} from "/extensions/ComfyUI-Lora-Manager/utils.js"; +import { addLorasWidget } from "/extensions/ComfyUI-Lora-Manager/loras_widget.js"; +import { applyLoraValuesToText, debounce } from "/extensions/ComfyUI-Lora-Manager/lora_syntax_utils.js"; +import { applySelectionHighlight } from "/extensions/ComfyUI-Lora-Manager/trigger_word_highlight.js"; + +app.registerExtension({ + name: "LoraManager.LoraLoaderRemote", + + setup() { + api.addEventListener("lora_code_update", (event) => { + this.handleLoraCodeUpdate(event.detail || {}); + }); + }, + + handleLoraCodeUpdate(message) { + const nodeId = message?.node_id ?? message?.id; + const graphId = message?.graph_id; + const loraCode = message?.lora_code ?? ""; + const mode = message?.mode ?? "append"; + const numericNodeId = typeof nodeId === "string" ? Number(nodeId) : nodeId; + + if (numericNodeId === -1) { + const loraLoaderNodes = getAllGraphNodes(app.graph) + .map(({ node }) => node) + .filter((node) => node?.comfyClass === "Lora Loader (Remote, LoraManager)"); + + if (loraLoaderNodes.length > 0) { + loraLoaderNodes.forEach((node) => { + this.updateNodeLoraCode(node, loraCode, mode); + }); + } + return; + } + + const node = getNodeFromGraph(graphId, numericNodeId); + if ( + !node || + (node.comfyClass !== "Lora Loader (Remote, LoraManager)" && + node.comfyClass !== "Lora Stacker (Remote, LoraManager)" && + node.comfyClass !== "WanVideo Lora Select (Remote, LoraManager)") + ) { + return; + } + this.updateNodeLoraCode(node, loraCode, mode); + }, + + updateNodeLoraCode(node, loraCode, mode) { + const inputWidget = node.inputWidget; + if (!inputWidget) return; + + const currentValue = inputWidget.value || ""; + if (mode === "replace") { + inputWidget.value = loraCode; + } else { + inputWidget.value = currentValue.trim() + ? `${currentValue.trim()} ${loraCode}` + : loraCode; + } + + if (typeof inputWidget.callback === "function") { + inputWidget.callback(inputWidget.value); + } + }, + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeType.comfyClass === "Lora Loader (Remote, LoraManager)") { + chainCallback(nodeType.prototype, "onNodeCreated", function () { + this.serialize_widgets = true; + + this.addInput("clip", "CLIP", { shape: 7 }); + this.addInput("lora_stack", "LORA_STACK", { shape: 7 }); + + let isUpdating = false; + let isSyncingInput = false; + + const self = this; + let _mode = this.mode; + Object.defineProperty(this, "mode", { + get() { return _mode; }, + set(value) { + const oldValue = _mode; + _mode = value; + if (self.onModeChange) self.onModeChange(value, oldValue); + }, + }); + + this.onModeChange = function (newMode) { + const allActiveLoraNames = collectActiveLorasFromChain(self); + updateConnectedTriggerWords(self, allActiveLoraNames); + }; + + const inputWidget = this.widgets[0]; + this.inputWidget = inputWidget; + + const scheduleInputSync = debounce((lorasValue) => { + if (isSyncingInput) return; + isSyncingInput = true; + isUpdating = true; + try { + const nextText = applyLoraValuesToText(inputWidget.value, lorasValue); + if (inputWidget.value !== nextText) inputWidget.value = nextText; + } finally { + isUpdating = false; + isSyncingInput = false; + } + }); + + this.lorasWidget = addLorasWidget( + this, "loras", + { onSelectionChange: (selection) => applySelectionHighlight(this, selection) }, + (value) => { + if (isUpdating) return; + isUpdating = true; + try { + const allActiveLoraNames = collectActiveLorasFromChain(this); + updateConnectedTriggerWords(this, allActiveLoraNames); + } finally { + isUpdating = false; + } + scheduleInputSync(value); + } + ).widget; + + inputWidget.callback = (value) => { + if (isUpdating) return; + isUpdating = true; + try { + const currentLoras = this.lorasWidget.value || []; + const mergedLoras = mergeLoras(value, currentLoras); + this.lorasWidget.value = mergedLoras; + const allActiveLoraNames = collectActiveLorasFromChain(this); + updateConnectedTriggerWords(this, allActiveLoraNames); + } finally { + isUpdating = false; + } + }; + }); + } + }, + + async loadedGraphNode(node) { + if (node.comfyClass === "Lora Loader (Remote, LoraManager)") { + let existingLoras = []; + if (node.widgets_values && node.widgets_values.length > 0) { + existingLoras = node.widgets_values[1] || []; + } + const mergedLoras = mergeLoras(node.widgets[0].value, existingLoras); + node.lorasWidget.value = mergedLoras; + } + }, +}); diff --git a/web/comfyui/lora_stacker_remote.js b/web/comfyui/lora_stacker_remote.js new file mode 100644 index 0000000..4ab3cfc --- /dev/null +++ b/web/comfyui/lora_stacker_remote.js @@ -0,0 +1,97 @@ +/** + * JS shim for Remote Lora Stacker node. + */ +import { app } from "../../scripts/app.js"; +import { + getActiveLorasFromNode, + updateConnectedTriggerWords, + updateDownstreamLoaders, + chainCallback, + mergeLoras, +} from "/extensions/ComfyUI-Lora-Manager/utils.js"; +import { addLorasWidget } from "/extensions/ComfyUI-Lora-Manager/loras_widget.js"; +import { applyLoraValuesToText, debounce } from "/extensions/ComfyUI-Lora-Manager/lora_syntax_utils.js"; +import { applySelectionHighlight } from "/extensions/ComfyUI-Lora-Manager/trigger_word_highlight.js"; + +app.registerExtension({ + name: "LoraManager.LoraStackerRemote", + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeType.comfyClass === "Lora Stacker (Remote, LoraManager)") { + chainCallback(nodeType.prototype, "onNodeCreated", async function () { + this.serialize_widgets = true; + this.addInput("lora_stack", "LORA_STACK", { shape: 7 }); + + let isUpdating = false; + let isSyncingInput = false; + + const inputWidget = this.widgets[0]; + this.inputWidget = inputWidget; + + const scheduleInputSync = debounce((lorasValue) => { + if (isSyncingInput) return; + isSyncingInput = true; + isUpdating = true; + try { + const nextText = applyLoraValuesToText(inputWidget.value, lorasValue); + if (inputWidget.value !== nextText) inputWidget.value = nextText; + } finally { + isUpdating = false; + isSyncingInput = false; + } + }); + + const result = addLorasWidget( + this, "loras", + { onSelectionChange: (selection) => applySelectionHighlight(this, selection) }, + (value) => { + if (isUpdating) return; + isUpdating = true; + try { + const isNodeActive = this.mode === undefined || this.mode === 0 || this.mode === 3; + const activeLoraNames = new Set(); + if (isNodeActive) { + value.forEach((lora) => { if (lora.active) activeLoraNames.add(lora.name); }); + } + updateConnectedTriggerWords(this, activeLoraNames); + updateDownstreamLoaders(this); + } finally { + isUpdating = false; + } + scheduleInputSync(value); + } + ); + + this.lorasWidget = result.widget; + + inputWidget.callback = (value) => { + if (isUpdating) return; + isUpdating = true; + try { + const currentLoras = this.lorasWidget?.value || []; + const mergedLoras = mergeLoras(value, currentLoras); + if (this.lorasWidget) this.lorasWidget.value = mergedLoras; + const isNodeActive = this.mode === undefined || this.mode === 0 || this.mode === 3; + const activeLoraNames = isNodeActive ? getActiveLorasFromNode(this) : new Set(); + updateConnectedTriggerWords(this, activeLoraNames); + updateDownstreamLoaders(this); + } finally { + isUpdating = false; + } + }; + }); + } + }, + + async loadedGraphNode(node) { + if (node.comfyClass === "Lora Stacker (Remote, LoraManager)") { + let existingLoras = []; + if (node.widgets_values && node.widgets_values.length > 0) { + existingLoras = node.widgets_values[1] || []; + } + const inputWidget = node.inputWidget || node.widgets[0]; + const mergedLoras = mergeLoras(inputWidget.value, existingLoras); + node.lorasWidget.value = mergedLoras; + } + }, +}); diff --git a/web/comfyui/wanvideo_remote.js b/web/comfyui/wanvideo_remote.js new file mode 100644 index 0000000..9187459 --- /dev/null +++ b/web/comfyui/wanvideo_remote.js @@ -0,0 +1,89 @@ +/** + * JS shim for Remote WanVideo Lora Select node. + */ +import { app } from "../../scripts/app.js"; +import { + getActiveLorasFromNode, + updateConnectedTriggerWords, + chainCallback, + mergeLoras, +} from "/extensions/ComfyUI-Lora-Manager/utils.js"; +import { addLorasWidget } from "/extensions/ComfyUI-Lora-Manager/loras_widget.js"; +import { applyLoraValuesToText, debounce } from "/extensions/ComfyUI-Lora-Manager/lora_syntax_utils.js"; + +app.registerExtension({ + name: "LoraManager.WanVideoLoraSelectRemote", + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeType.comfyClass === "WanVideo Lora Select (Remote, LoraManager)") { + chainCallback(nodeType.prototype, "onNodeCreated", async function () { + this.serialize_widgets = true; + + this.addInput("prev_lora", "WANVIDLORA", { shape: 7 }); + this.addInput("blocks", "SELECTEDBLOCKS", { shape: 7 }); + + let isUpdating = false; + let isSyncingInput = false; + + // text widget is at index 2 (after low_mem_load, merge_loras) + const inputWidget = this.widgets[2]; + this.inputWidget = inputWidget; + + const scheduleInputSync = debounce((lorasValue) => { + if (isSyncingInput) return; + isSyncingInput = true; + isUpdating = true; + try { + const nextText = applyLoraValuesToText(inputWidget.value, lorasValue); + if (inputWidget.value !== nextText) inputWidget.value = nextText; + } finally { + isUpdating = false; + isSyncingInput = false; + } + }); + + const result = addLorasWidget(this, "loras", {}, (value) => { + if (isUpdating) return; + isUpdating = true; + try { + const activeLoraNames = new Set(); + value.forEach((lora) => { if (lora.active) activeLoraNames.add(lora.name); }); + updateConnectedTriggerWords(this, activeLoraNames); + } finally { + isUpdating = false; + } + scheduleInputSync(value); + }); + + this.lorasWidget = result.widget; + + inputWidget.callback = (value) => { + if (isUpdating) return; + isUpdating = true; + try { + const currentLoras = this.lorasWidget?.value || []; + const mergedLoras = mergeLoras(value, currentLoras); + if (this.lorasWidget) this.lorasWidget.value = mergedLoras; + const activeLoraNames = getActiveLorasFromNode(this); + updateConnectedTriggerWords(this, activeLoraNames); + } finally { + isUpdating = false; + } + }; + }); + } + }, + + async loadedGraphNode(node) { + if (node.comfyClass === "WanVideo Lora Select (Remote, LoraManager)") { + let existingLoras = []; + if (node.widgets_values && node.widgets_values.length > 0) { + // 0=low_mem_load, 1=merge_loras, 2=text, 3=loras + existingLoras = node.widgets_values[3] || []; + } + const inputWidget = node.inputWidget || node.widgets[2]; + const mergedLoras = mergeLoras(inputWidget.value, existingLoras); + node.lorasWidget.value = mergedLoras; + } + }, +});