feat: initial release of ComfyUI-LM-Remote
Remote-aware LoRA Manager nodes that fetch metadata via HTTP from a remote Docker instance while loading LoRA files from local NFS/SMB mounts. Includes reverse-proxy middleware for transparent web UI access. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
54
__init__.py
Normal file
54
__init__.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
ComfyUI-LM-Remote — Remote LoRA Manager integration for ComfyUI.
|
||||
|
||||
Provides:
|
||||
1. A reverse-proxy middleware that forwards all LoRA Manager API/UI/WS
|
||||
requests to a remote Docker instance.
|
||||
2. Remote-aware node classes that fetch metadata via HTTP instead of the
|
||||
local ServiceRegistry, while still loading LoRA files from local
|
||||
NFS/SMB-mounted paths.
|
||||
|
||||
Requires the original ComfyUI-Lora-Manager package to be installed alongside
|
||||
for its widget JS files and custom widget types.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Import node classes ────────────────────────────────────────────────
|
||||
from .nodes.lora_loader import LoraLoaderRemoteLM, LoraTextLoaderRemoteLM
|
||||
from .nodes.lora_stacker import LoraStackerRemoteLM
|
||||
from .nodes.lora_randomizer import LoraRandomizerRemoteLM
|
||||
from .nodes.lora_cycler import LoraCyclerRemoteLM
|
||||
from .nodes.lora_pool import LoraPoolRemoteLM
|
||||
from .nodes.save_image import SaveImageRemoteLM
|
||||
from .nodes.wanvideo import WanVideoLoraSelectRemoteLM, WanVideoLoraTextSelectRemoteLM
|
||||
|
||||
# ── NODE_CLASS_MAPPINGS (how ComfyUI discovers nodes) ──────────────────
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
LoraLoaderRemoteLM.NAME: LoraLoaderRemoteLM,
|
||||
LoraTextLoaderRemoteLM.NAME: LoraTextLoaderRemoteLM,
|
||||
LoraStackerRemoteLM.NAME: LoraStackerRemoteLM,
|
||||
LoraRandomizerRemoteLM.NAME: LoraRandomizerRemoteLM,
|
||||
LoraCyclerRemoteLM.NAME: LoraCyclerRemoteLM,
|
||||
LoraPoolRemoteLM.NAME: LoraPoolRemoteLM,
|
||||
SaveImageRemoteLM.NAME: SaveImageRemoteLM,
|
||||
WanVideoLoraSelectRemoteLM.NAME: WanVideoLoraSelectRemoteLM,
|
||||
WanVideoLoraTextSelectRemoteLM.NAME: WanVideoLoraTextSelectRemoteLM,
|
||||
}
|
||||
|
||||
# ── WEB_DIRECTORY tells ComfyUI where to find our JS extensions ───────
|
||||
WEB_DIRECTORY = "./web/comfyui"
|
||||
|
||||
# ── Register proxy middleware ──────────────────────────────────────────
|
||||
try:
|
||||
from server import PromptServer # type: ignore
|
||||
from .proxy import register_proxy
|
||||
|
||||
register_proxy(PromptServer.instance.app)
|
||||
except Exception as exc:
|
||||
logger.warning("[LM-Remote] Could not register proxy middleware: %s", exc)
|
||||
|
||||
__all__ = ["NODE_CLASS_MAPPINGS", "WEB_DIRECTORY"]
|
||||
5
config.json
Normal file
5
config.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"remote_url": "http://192.168.1.100:8188",
|
||||
"timeout": 30,
|
||||
"path_mappings": {}
|
||||
}
|
||||
62
config.py
Normal file
62
config.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Configuration for ComfyUI-LM-Remote."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PACKAGE_DIR = Path(__file__).resolve().parent
|
||||
_CONFIG_FILE = _PACKAGE_DIR / "config.json"
|
||||
|
||||
|
||||
class RemoteConfig:
|
||||
"""Holds remote LoRA Manager connection settings."""
|
||||
|
||||
def __init__(self):
|
||||
self.remote_url: str = ""
|
||||
self.timeout: int = 30
|
||||
self.path_mappings: dict[str, str] = {}
|
||||
self._load()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
def _load(self):
|
||||
# Environment variable takes priority
|
||||
env_url = os.environ.get("LM_REMOTE_URL", "")
|
||||
env_timeout = os.environ.get("LM_REMOTE_TIMEOUT", "")
|
||||
|
||||
# Load config.json defaults
|
||||
if _CONFIG_FILE.exists():
|
||||
try:
|
||||
with open(_CONFIG_FILE, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
self.remote_url = data.get("remote_url", "")
|
||||
self.timeout = int(data.get("timeout", 30))
|
||||
self.path_mappings = data.get("path_mappings", {})
|
||||
except Exception as exc:
|
||||
logger.warning("[LM-Remote] Failed to read config.json: %s", exc)
|
||||
|
||||
# Env overrides
|
||||
if env_url:
|
||||
self.remote_url = env_url
|
||||
if env_timeout:
|
||||
self.timeout = int(env_timeout)
|
||||
|
||||
# Strip trailing slash
|
||||
self.remote_url = self.remote_url.rstrip("/")
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self.remote_url)
|
||||
|
||||
def map_path(self, remote_path: str) -> str:
|
||||
"""Apply remote->local path prefix mappings."""
|
||||
for remote_prefix, local_prefix in self.path_mappings.items():
|
||||
if remote_path.startswith(remote_prefix):
|
||||
return local_prefix + remote_path[len(remote_prefix):]
|
||||
return remote_path
|
||||
|
||||
|
||||
# Module-level singleton: import `remote_config` rather than instantiating
# RemoteConfig directly, so every node shares one settings object.
remote_config = RemoteConfig()
|
||||
0
nodes/__init__.py
Normal file
0
nodes/__init__.py
Normal file
92
nodes/lora_cycler.py
Normal file
92
nodes/lora_cycler.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Remote LoRA Cycler — uses remote API instead of local ServiceRegistry."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from .remote_utils import get_lora_info_remote
|
||||
from ..remote_client import RemoteLoraClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraCyclerRemoteLM:
    """Node that sequentially cycles through LoRAs from a pool (remote)."""

    NAME = "Lora Cycler (Remote, LoraManager)"
    CATEGORY = "Lora Manager/randomizer"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "cycler_config": ("CYCLER_CONFIG", {}),
            },
            "optional": {
                "pool_config": ("POOL_CONFIG", {}),
            },
        }

    RETURN_TYPES = ("LORA_STACK",)
    RETURN_NAMES = ("LORA_STACK",)
    FUNCTION = "cycle"
    OUTPUT_NODE = False

    async def cycle(self, cycler_config, pool_config=None):
        """Select one LoRA from the remote pool by 1-based index.

        Returns a single-entry LORA_STACK plus UI state describing the
        current and next positions in the cycle.
        """
        model_strength = float(cycler_config.get("model_strength", 1.0))
        clip_strength = float(cycler_config.get("clip_strength", 1.0))

        # The widget-provided execution_index (when present) wins over the
        # stored current_index.
        index = cycler_config.get("execution_index")
        if index is None:
            index = cycler_config.get("current_index", 1)

        remote = RemoteLoraClient.get_instance()
        pool = await remote.get_cycler_list(
            pool_config=pool_config, sort_by="filename"
        )
        total = len(pool)

        if not pool:
            logger.warning("[LoraCyclerRemoteLM] No LoRAs available in pool")
            return {
                "result": ([],),
                "ui": {
                    "current_index": [1],
                    "next_index": [1],
                    "total_count": [0],
                    "current_lora_name": [""],
                    "current_lora_filename": [""],
                    "error": ["No LoRAs available in pool"],
                },
            }

        # Clamp into [1, total] (1-based, matching the widget).
        index = max(1, min(index, total))
        chosen = pool[index - 1]
        filename = chosen["file_name"]

        resolved_path, _ = get_lora_info_remote(filename)
        if resolved_path:
            stack = [(resolved_path.replace("/", os.sep), model_strength, clip_strength)]
        else:
            logger.warning("[LoraCyclerRemoteLM] Could not find path for LoRA: %s", filename)
            stack = []

        # Wrap around to the first entry after the last one.
        upcoming = index % total + 1
        upcoming_name = pool[upcoming - 1]["file_name"]

        return {
            "result": (stack,),
            "ui": {
                "current_index": [index],
                "next_index": [upcoming],
                "total_count": [total],
                "current_lora_name": [filename],
                "current_lora_filename": [filename],
                "next_lora_name": [upcoming_name],
                "next_lora_filename": [upcoming_name],
            },
        }
|
||||
205
nodes/lora_loader.py
Normal file
205
nodes/lora_loader.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Remote LoRA Loader nodes — fetch metadata from the remote LoRA Manager."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from nodes import LoraLoader # type: ignore
|
||||
|
||||
from .remote_utils import get_lora_info_remote
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraLoaderRemoteLM:
    """Load LoRAs onto a model/clip pair, resolving metadata remotely.

    Accepts LoRAs both from an upstream LORA_STACK and from the node's
    autocomplete widget; paths and trigger words come from the remote
    LoRA Manager instance.
    """

    NAME = "Lora Loader (Remote, LoraManager)"
    CATEGORY = "Lora Manager/loaders"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "text": ("AUTOCOMPLETE_TEXT_LORAS", {
                    "placeholder": "Search LoRAs to add...",
                    "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
                }),
            },
            "optional": FlexibleOptionalInputType(any_type),
        }

    RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
    RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
    FUNCTION = "load_loras"

    @staticmethod
    def _detect_nunchaku(model) -> bool:
        """Return True when the model wraps a Nunchaku Flux diffusion model."""
        try:
            wrapper = model.model.diffusion_model
            if wrapper.__class__.__name__ == "ComfyFluxWrapper":
                logger.info("Detected Nunchaku Flux model")
                return True
        except (AttributeError, TypeError):
            pass
        return False

    def load_loras(self, model, text, **kwargs):
        """Apply all active LoRAs and collect their trigger words.

        Returns:
            (model, clip, trigger_words, loaded_loras) — the last two are
            display strings for downstream text nodes.
        """
        # (name, model_strength, clip_strength) for every applied LoRA.
        applied: list[tuple[str, float, float]] = []
        all_trigger_words: list[str] = []

        clip = kwargs.get("clip", None)
        lora_stack = kwargs.get("lora_stack", None)

        is_nunchaku_model = self._detect_nunchaku(model)

        # Process lora_stack entries (paths already resolved upstream).
        if lora_stack:
            for lora_path, model_strength, clip_strength in lora_stack:
                if is_nunchaku_model:
                    model = nunchaku_load_lora(model, lora_path, model_strength)
                else:
                    model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)

                lora_name = extract_lora_name(lora_path)
                _, trigger_words = get_lora_info_remote(lora_name)
                all_trigger_words.extend(trigger_words)
                applied.append((lora_name, model_strength, clip_strength))

        # Process loras selected in the widget (need remote path resolution).
        for lora in get_loras_list(kwargs):
            if not lora.get("active", False):
                continue

            lora_name = lora["name"]
            model_strength = float(lora["strength"])
            clip_strength = float(lora.get("clipStrength", model_strength))

            lora_path, trigger_words = get_lora_info_remote(lora_name)

            if is_nunchaku_model:
                model = nunchaku_load_lora(model, lora_path, model_strength)
            else:
                model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)

            applied.append((lora_name, model_strength, clip_strength))
            all_trigger_words.extend(trigger_words)

        trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

        # Format directly from the tuples instead of round-tripping through
        # "name: strength" strings and re-splitting on ":", which corrupted
        # the output for LoRA names containing a colon.
        formatted_loras = []
        for lora_name, model_strength, clip_strength in applied:
            # Nunchaku loads ignore clip strength, so always show one value.
            if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
                formatted_loras.append(f"<lora:{lora_name}:{model_strength}:{clip_strength}>")
            else:
                formatted_loras.append(f"<lora:{lora_name}:{model_strength}>")

        return (model, clip, trigger_words_text, " ".join(formatted_loras))
|
||||
|
||||
|
||||
class LoraTextLoaderRemoteLM:
    """Load LoRAs described by ``<lora:name:strength>`` syntax in a string.

    Metadata (paths, trigger words) is resolved via the remote LoRA Manager.
    """

    NAME = "LoRA Text Loader (Remote, LoraManager)"
    CATEGORY = "Lora Manager/loaders"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "lora_syntax": ("STRING", {
                    "forceInput": True,
                    "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
                }),
            },
            "optional": {
                "clip": ("CLIP",),
                "lora_stack": ("LORA_STACK",),
            },
        }

    RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
    RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
    FUNCTION = "load_loras_from_text"

    def parse_lora_syntax(self, text):
        """Parse ``<lora:name:model_strength[:clip_strength]>`` tags.

        Tags whose strength is not a valid number are skipped with a warning
        instead of raising ValueError and aborting the whole load.

        Returns:
            list of dicts with ``name``, ``model_strength``, ``clip_strength``.
        """
        pattern = r"<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>"
        loras = []
        for name, model_s, clip_s in re.findall(pattern, text, re.IGNORECASE):
            try:
                model_strength = float(model_s)
                clip_strength = float(clip_s) if clip_s else model_strength
            except ValueError:
                logger.warning(
                    "[LoraTextLoaderRemoteLM] Ignoring LoRA tag with non-numeric strength: %s", name
                )
                continue
            loras.append({
                "name": name,
                "model_strength": model_strength,
                "clip_strength": clip_strength,
            })
        return loras

    def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
        """Apply every LoRA found in *lora_syntax* (plus upstream lora_stack).

        Returns:
            (model, clip, trigger_words, loaded_loras) — the last two are
            display strings for downstream text nodes.
        """
        # (name, model_strength, clip_strength) for every applied LoRA.
        applied = []
        all_trigger_words = []

        is_nunchaku_model = False
        try:
            model_wrapper = model.model.diffusion_model
            if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
                is_nunchaku_model = True
                logger.info("Detected Nunchaku Flux model")
        except (AttributeError, TypeError):
            pass

        # Pass-through stack entries (paths already resolved upstream).
        if lora_stack:
            for lora_path, model_strength, clip_strength in lora_stack:
                if is_nunchaku_model:
                    model = nunchaku_load_lora(model, lora_path, model_strength)
                else:
                    model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)

                lora_name = extract_lora_name(lora_path)
                _, trigger_words = get_lora_info_remote(lora_name)
                all_trigger_words.extend(trigger_words)
                applied.append((lora_name, model_strength, clip_strength))

        # LoRAs written in the text input (need remote path resolution).
        for lora in self.parse_lora_syntax(lora_syntax):
            lora_name = lora["name"]
            model_strength = lora["model_strength"]
            clip_strength = lora["clip_strength"]

            lora_path, trigger_words = get_lora_info_remote(lora_name)

            if is_nunchaku_model:
                model = nunchaku_load_lora(model, lora_path, model_strength)
            else:
                model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)

            applied.append((lora_name, model_strength, clip_strength))
            all_trigger_words.extend(trigger_words)

        trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

        # Format directly from the tuples instead of round-tripping through
        # "name: strength" strings and re-splitting on ":", which corrupted
        # the output for LoRA names containing a colon.
        formatted_loras = []
        for lora_name, model_strength, clip_strength in applied:
            if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
                formatted_loras.append(f"<lora:{lora_name.strip()}:{model_strength}:{clip_strength}>")
            else:
                formatted_loras.append(f"<lora:{lora_name.strip()}:{model_strength}>")

        return (model, clip, trigger_words_text, " ".join(formatted_loras))
|
||||
55
nodes/lora_pool.py
Normal file
55
nodes/lora_pool.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Remote LoRA Pool — pure pass-through, just different NAME."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraPoolRemoteLM:
    """LoRA Pool that passes through filter config (remote variant for NAME only)."""

    NAME = "Lora Pool (Remote, LoraManager)"
    CATEGORY = "Lora Manager/randomizer"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "pool_config": ("LORA_POOL_CONFIG", {}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("POOL_CONFIG",)
    RETURN_NAMES = ("POOL_CONFIG",)
    FUNCTION = "process"
    OUTPUT_NODE = False

    def process(self, pool_config, unique_id=None):
        """Validate the widget payload and emit only its `filters` section."""
        if not isinstance(pool_config, dict):
            logger.warning("Invalid pool_config type, using empty config")
            pool_config = self._default_config()

        # Older widget payloads may lack a schema version; stamp one in.
        pool_config.setdefault("version", 1)

        fallback_filters = self._default_config()["filters"]
        selected = pool_config.get("filters", fallback_filters)
        logger.debug("[LoraPoolRemoteLM] Processing filters: %s", selected)
        return (selected,)

    @staticmethod
    def _default_config():
        """Build a fresh default pool config (new dict each call — safe to mutate)."""
        return {
            "version": 1,
            "filters": {
                "baseModels": [],
                "tags": {"include": [], "exclude": []},
                "folders": {"include": [], "exclude": []},
                "favoritesOnly": False,
                "license": {"noCreditRequired": False, "allowSelling": False},
            },
            "preview": {"matchCount": 0, "lastUpdated": 0},
        }
|
||||
113
nodes/lora_randomizer.py
Normal file
113
nodes/lora_randomizer.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Remote LoRA Randomizer — uses remote API instead of local ServiceRegistry."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from .remote_utils import get_lora_info_remote
|
||||
from .utils import extract_lora_name
|
||||
from ..remote_client import RemoteLoraClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraRandomizerRemoteLM:
    """Node that randomly selects LoRAs from a pool (remote metadata)."""

    NAME = "Lora Randomizer (Remote, LoraManager)"
    CATEGORY = "Lora Manager/randomizer"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "randomizer_config": ("RANDOMIZER_CONFIG", {}),
                "loras": ("LORAS", {}),
            },
            "optional": {
                "pool_config": ("POOL_CONFIG", {}),
            },
        }

    RETURN_TYPES = ("LORA_STACK",)
    RETURN_NAMES = ("LORA_STACK",)
    FUNCTION = "randomize"
    OUTPUT_NODE = False

    def _preprocess_loras_input(self, loras):
        """Unwrap the widget's {"__value__": [...]} envelope when present."""
        if isinstance(loras, dict) and "__value__" in loras:
            return loras["__value__"]
        return loras

    async def randomize(self, randomizer_config, loras, pool_config=None):
        """Roll a new selection (unless mode is "fixed") and build a stack."""
        loras = self._preprocess_loras_input(loras)

        if randomizer_config.get("roll_mode", "always") == "fixed":
            # Fixed mode: reuse the incoming selection for execution and UI.
            ui_loras = loras
            execution_loras = loras
        else:
            remote = RemoteLoraClient.get_instance()
            base_kwargs = self._build_api_kwargs(randomizer_config, loras, pool_config)

            execution_seed = randomizer_config.get("execution_seed", None)
            if execution_seed is None:
                execution_loras = loras
            else:
                rolled = await remote.get_random_loras(**base_kwargs, seed=execution_seed)
                execution_loras = rolled if rolled else loras

            # Pre-roll the *next* selection so the UI can preview it.
            preview = await remote.get_random_loras(
                **base_kwargs, seed=randomizer_config.get("next_seed", None)
            )
            ui_loras = preview if preview else loras

        return {
            "result": (self._build_execution_stack_from_input(execution_loras),),
            "ui": {"loras": ui_loras, "last_used": execution_loras},
        }

    def _build_api_kwargs(self, randomizer_config, input_loras, pool_config):
        """Translate the widget config into remote random-roll API parameters."""
        cfg = randomizer_config.get
        return {
            "count": int(cfg("count_fixed", 5)),
            "count_mode": cfg("count_mode", "range"),
            "count_min": int(cfg("count_min", 3)),
            "count_max": int(cfg("count_max", 7)),
            "model_strength_min": float(cfg("model_strength_min", 0.0)),
            "model_strength_max": float(cfg("model_strength_max", 1.0)),
            "use_same_clip_strength": cfg("use_same_clip_strength", True),
            "clip_strength_min": float(cfg("clip_strength_min", 0.0)),
            "clip_strength_max": float(cfg("clip_strength_max", 1.0)),
            "use_recommended_strength": cfg("use_recommended_strength", False),
            "recommended_strength_scale_min": float(cfg("recommended_strength_scale_min", 0.5)),
            "recommended_strength_scale_max": float(cfg("recommended_strength_scale_max", 1.0)),
            # Locked entries are preserved while the rest are rerolled.
            "locked_loras": [entry for entry in input_loras if entry.get("locked", False)],
            "pool_config": pool_config,
        }

    def _build_execution_stack_from_input(self, loras):
        """Resolve active LoRAs to local (path, model_str, clip_str) tuples."""
        stack = []
        for entry in loras:
            if not entry.get("active", False):
                continue

            local_path, _ = get_lora_info_remote(entry["name"])
            if not local_path:
                logger.warning("[LoraRandomizerRemoteLM] Could not find path for LoRA: %s", entry["name"])
                continue

            strength = float(entry.get("strength", 1.0))
            stack.append((
                local_path.replace("/", os.sep),
                strength,
                float(entry.get("clipStrength", strength)),
            ))

        return stack
|
||||
71
nodes/lora_stacker.py
Normal file
71
nodes/lora_stacker.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Remote LoRA Stacker — fetch metadata from the remote LoRA Manager."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from .remote_utils import get_lora_info_remote
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraStackerRemoteLM:
    """Stack LoRAs (without applying them) using remotely-resolved paths."""

    NAME = "Lora Stacker (Remote, LoraManager)"
    CATEGORY = "Lora Manager/stackers"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "text": ("AUTOCOMPLETE_TEXT_LORAS", {
                    "placeholder": "Search LoRAs to add...",
                    "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
                }),
            },
            "optional": FlexibleOptionalInputType(any_type),
        }

    RETURN_TYPES = ("LORA_STACK", "STRING", "STRING")
    RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
    FUNCTION = "stack_loras"

    def stack_loras(self, text, **kwargs):
        """Merge an upstream lora_stack with widget selections into one stack.

        Returns:
            (stack, trigger_words, active_loras) — the combined stack, the
            collected trigger words, and a display string of widget loras.
        """
        stack = []
        active_loras = []
        all_trigger_words = []

        # Pass through any upstream stack, collecting its trigger words.
        lora_stack = kwargs.get("lora_stack", None)
        if lora_stack:
            stack.extend(lora_stack)
            for lora_path, _, _ in lora_stack:
                lora_name = extract_lora_name(lora_path)
                _, trigger_words = get_lora_info_remote(lora_name)
                all_trigger_words.extend(trigger_words)

        # Widget selections: resolve each to a local path via the remote API.
        for lora in get_loras_list(kwargs):
            if not lora.get("active", False):
                continue

            lora_name = lora["name"]
            model_strength = float(lora["strength"])
            clip_strength = float(lora.get("clipStrength", model_strength))

            lora_path, trigger_words = get_lora_info_remote(lora_name)
            if not lora_path:
                # Match the other remote nodes (cycler/randomizer): skip
                # unknown loras instead of stacking an empty/None path,
                # which would break the downstream loader.
                logger.warning("[LoraStackerRemoteLM] Could not find path for LoRA: %s", lora_name)
                continue

            stack.append((lora_path.replace("/", os.sep), model_strength, clip_strength))
            active_loras.append((lora_name, model_strength, clip_strength))
            all_trigger_words.extend(trigger_words)

        trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

        # <lora:name:ms[:cs]> display form; the clip strength is shown only
        # when it meaningfully differs from the model strength.
        formatted_loras = []
        for name, model_strength, clip_strength in active_loras:
            if abs(model_strength - clip_strength) > 0.001:
                formatted_loras.append(f"<lora:{name}:{model_strength}:{clip_strength}>")
            else:
                formatted_loras.append(f"<lora:{name}:{model_strength}>")

        return (stack, trigger_words_text, " ".join(formatted_loras))
|
||||
44
nodes/remote_utils.py
Normal file
44
nodes/remote_utils.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Remote replacement for ``py/utils/utils.py:get_lora_info()``.
|
||||
|
||||
Same signature: ``get_lora_info_remote(lora_name) -> (relative_path, trigger_words)``
|
||||
but fetches data from the remote LoRA Manager HTTP API instead of the local
|
||||
ServiceRegistry / SQLite cache.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
|
||||
from ..remote_client import RemoteLoraClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_lora_info_remote(lora_name: str) -> tuple[str, list[str]]:
    """Synchronous wrapper that calls the remote API for LoRA metadata.

    Uses the same sync-from-async bridge pattern as the original
    ``get_lora_info()`` to be a drop-in replacement in node ``FUNCTION``
    methods.

    Returns:
        ``(relative_path, trigger_words)`` as reported by the remote
        LoRA Manager.
    """
    async def _fetch():
        client = RemoteLoraClient.get_instance()
        return await client.get_lora_info(lora_name)

    # Keep only the loop detection inside the try. Previously the whole
    # thread-pool path was inside it, so a RuntimeError raised by the remote
    # fetch itself fell through to asyncio.run() and executed the request a
    # second time.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop — safe to use asyncio.run()
        return asyncio.run(_fetch())

    # Already inside an event loop — drive the fetch on its own loop in a
    # worker thread so we don't block or re-enter the caller's loop.
    def _run_in_thread():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(_fetch())
        finally:
            loop.close()

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(_run_in_thread).result()
|
||||
381
nodes/save_image.py
Normal file
381
nodes/save_image.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Remote Save Image — uses remote API for hash lookups."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
import folder_paths # type: ignore
|
||||
import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from ..remote_client import RemoteLoraClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import piexif
|
||||
except ImportError:
|
||||
piexif = None
|
||||
|
||||
|
||||
class SaveImageRemoteLM:
|
||||
NAME = "Save Image (Remote, LoraManager)"
|
||||
CATEGORY = "Lora Manager/utils"
|
||||
DESCRIPTION = "Save images with embedded generation metadata (remote hash lookup)"
|
||||
|
||||
def __init__(self):
    """Set up the output location and per-instance save bookkeeping."""
    # Resolve ComfyUI's configured output directory once at node creation.
    self.output_dir = folder_paths.get_output_directory()
    # Result type tag for saved files ("output" — standard ComfyUI save convention).
    self.type = "output"
    self.prefix_append = ""
    # PNG compression level — presumably forwarded to PIL's save(); the save
    # path is not visible in this chunk, confirm there.
    self.compress_level = 4
    # Running counter — presumably used with add_counter_to_filename; verify in save path.
    self.counter = 0
|
||||
|
||||
# Pre-compiled matcher for %token% placeholders in filename_prefix
# (e.g. %seed%, %width% — see the filename_prefix tooltip).
pattern_format = re.compile(r"(%[^%]+%)")
|
||||
|
||||
@classmethod
def INPUT_TYPES(cls):
    """Declare the node's inputs for ComfyUI's node registry."""
    return {
        "required": {
            "images": ("IMAGE",),
            "filename_prefix": ("STRING", {
                "default": "ComfyUI",
                "tooltip": "Base filename. Supports %seed%, %width%, %height%, %model%, etc.",
            }),
            "file_format": (["png", "jpeg", "webp"], {
                "tooltip": "Image format to save as.",
            }),
        },
        "optional": {
            # presumably only honored for webp output — confirm in save path
            "lossless_webp": ("BOOLEAN", {"default": False}),
            "quality": ("INT", {"default": 100, "min": 1, "max": 100}),
            # NOTE(review): looks like this embeds the workflow JSON into the
            # saved image's metadata — confirm in the (not visible) save path.
            "embed_workflow": ("BOOLEAN", {"default": False}),
            "add_counter_to_filename": ("BOOLEAN", {"default": True}),
        },
        "hidden": {
            # Supplied by ComfyUI at execution time; not user-visible inputs.
            "id": "UNIQUE_ID",
            "prompt": "PROMPT",
            "extra_pnginfo": "EXTRA_PNGINFO",
        },
    }
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("images",)
|
||||
FUNCTION = "process_image"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Remote hash lookups
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_async(self, coro):
|
||||
"""Run an async coroutine from sync context."""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
|
||||
def _in_thread():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(_in_thread).result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
|
||||
def get_lora_hash(self, lora_name):
    """Fetch the hash for *lora_name* from the remote LoRA Manager (blocking)."""
    client = RemoteLoraClient.get_instance()
    # Bridge the async client call into this synchronous node method.
    return self._run_async(client.get_lora_hash(lora_name))
|
||||
|
||||
def get_checkpoint_hash(self, checkpoint_path):
|
||||
if not checkpoint_path:
|
||||
return None
|
||||
checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
|
||||
client = RemoteLoraClient.get_instance()
|
||||
return self._run_async(client.get_checkpoint_hash(checkpoint_name))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Metadata formatting (identical to original)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def format_metadata(self, metadata_dict):
    """Build a "parameters"-style metadata string from collected workflow data.

    Produces the prompt (plus any <lora:...> syntax), an optional
    "Negative prompt:" line, and a comma-separated parameter line
    (Steps, Sampler, CFG scale, Seed, Size, Model, Lora hashes) —
    a layout resembling the A1111 metadata convention. LoRA and
    checkpoint hashes are resolved through the remote LoRA Manager.

    Returns an empty string when *metadata_dict* is falsy.
    """
    if not metadata_dict:
        return ""

    def add_param_if_not_none(param_list, label, value):
        # Append "Label: value" only when the value is present.
        if value is not None:
            param_list.append(f"{label}: {value}")

    prompt = metadata_dict.get("prompt", "")
    negative_prompt = metadata_dict.get("negative_prompt", "")
    loras_text = metadata_dict.get("loras", "")
    lora_hashes = {}

    if loras_text:
        # Keep the <lora:...> syntax visible in the saved prompt, and
        # resolve a hash for every referenced LoRA via the remote API.
        prompt_with_loras = f"{prompt}\n{loras_text}"
        lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", loras_text)
        for lora_name, _ in lora_matches:
            hash_value = self.get_lora_hash(lora_name)
            if hash_value:
                lora_hashes[lora_name] = hash_value
    else:
        prompt_with_loras = prompt

    metadata_parts = [prompt_with_loras]

    if negative_prompt:
        metadata_parts.append(f"Negative prompt: {negative_prompt}")

    params = []

    if "steps" in metadata_dict:
        add_param_if_not_none(params, "Steps", metadata_dict.get("steps"))

    sampler_name = None
    scheduler_name = None

    if "sampler" in metadata_dict:
        # Map ComfyUI sampler identifiers to common display names;
        # unknown samplers pass through unchanged.
        sampler_mapping = {
            "euler": "Euler", "euler_ancestral": "Euler a",
            "dpm_2": "DPM2", "dpm_2_ancestral": "DPM2 a",
            "heun": "Heun", "dpm_fast": "DPM fast",
            "dpm_adaptive": "DPM adaptive", "lms": "LMS",
            "dpmpp_2s_ancestral": "DPM++ 2S a", "dpmpp_sde": "DPM++ SDE",
            "dpmpp_sde_gpu": "DPM++ SDE", "dpmpp_2m": "DPM++ 2M",
            "dpmpp_2m_sde": "DPM++ 2M SDE", "dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
            "ddim": "DDIM",
        }
        sampler_name = sampler_mapping.get(metadata_dict["sampler"], metadata_dict["sampler"])

    if "scheduler" in metadata_dict:
        # Same idea for scheduler names; unknown values pass through.
        scheduler_mapping = {
            "normal": "Simple", "karras": "Karras",
            "exponential": "Exponential", "sgm_uniform": "SGM Uniform",
            "sgm_quadratic": "SGM Quadratic",
        }
        scheduler_name = scheduler_mapping.get(metadata_dict["scheduler"], metadata_dict["scheduler"])

    if sampler_name:
        # Scheduler is appended to the sampler name when both are known.
        if scheduler_name:
            params.append(f"Sampler: {sampler_name} {scheduler_name}")
        else:
            params.append(f"Sampler: {sampler_name}")

    # CFG scale may arrive under different keys depending on the workflow;
    # check them in priority order.
    if "guidance" in metadata_dict:
        add_param_if_not_none(params, "CFG scale", metadata_dict.get("guidance"))
    elif "cfg_scale" in metadata_dict:
        add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg_scale"))
    elif "cfg" in metadata_dict:
        add_param_if_not_none(params, "CFG scale", metadata_dict.get("cfg"))

    if "seed" in metadata_dict:
        add_param_if_not_none(params, "Seed", metadata_dict.get("seed"))

    if "size" in metadata_dict:
        add_param_if_not_none(params, "Size", metadata_dict.get("size"))

    if "checkpoint" in metadata_dict:
        checkpoint = metadata_dict.get("checkpoint")
        if checkpoint is not None:
            # The hash lookup can return None (remote miss); fall back to
            # just the bare model name in that case.
            model_hash = self.get_checkpoint_hash(checkpoint)
            checkpoint_name = os.path.splitext(os.path.basename(checkpoint))[0]
            if model_hash:
                params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
            else:
                params.append(f"Model: {checkpoint_name}")

    if lora_hashes:
        # LoRA hashes are truncated to 10 chars, matching the model hash.
        lora_hash_parts = [f"{n}: {h[:10]}" for n, h in lora_hashes.items()]
        params.append(f'Lora hashes: "{", ".join(lora_hash_parts)}"')

    metadata_parts.append(", ".join(params))
    return "\n".join(metadata_parts)
|
||||
|
||||
def format_filename(self, filename, metadata_dict):
    """Expand %key% / %key:arg% placeholders in *filename* from metadata.

    Supported keys: seed, width, height, pprompt, nprompt, model, date.
    The optional numeric arg truncates prompt/model text; for %date% it is
    a pattern built from yyyy/yy/MM/dd/hh/mm/ss tokens.

    NOTE(review): placeholders are matched with ``self.pattern_format``,
    defined outside this block — presumably a %...% regex; confirm against
    the class definition.
    """
    if not metadata_dict:
        return filename

    result = re.findall(self.pattern_format, filename)
    for segment in result:
        # A segment looks like "%key%" or "%key:arg%": strip the percent
        # signs, then split the key from its optional argument.
        parts = segment.replace("%", "").split(":")
        key = parts[0]

        if key == "seed" and "seed" in metadata_dict:
            filename = filename.replace(segment, str(metadata_dict.get("seed", "")))
        elif key == "width" and "size" in metadata_dict:
            # "size" may be a "WxH" string or a (W, H) sequence.
            size = metadata_dict.get("size", "x")
            w = size.split("x")[0] if isinstance(size, str) else size[0]
            filename = filename.replace(segment, str(w))
        elif key == "height" and "size" in metadata_dict:
            size = metadata_dict.get("size", "x")
            h = size.split("x")[1] if isinstance(size, str) else size[1]
            filename = filename.replace(segment, str(h))
        elif key == "pprompt" and "prompt" in metadata_dict:
            # Positive prompt, flattened to one line, optionally truncated.
            p = metadata_dict.get("prompt", "").replace("\n", " ")
            if len(parts) >= 2:
                p = p[: int(parts[1])]
            filename = filename.replace(segment, p.strip())
        elif key == "nprompt" and "negative_prompt" in metadata_dict:
            # Negative prompt, same treatment as %pprompt%.
            p = metadata_dict.get("negative_prompt", "").replace("\n", " ")
            if len(parts) >= 2:
                p = p[: int(parts[1])]
            filename = filename.replace(segment, p.strip())
        elif key == "model":
            model_value = metadata_dict.get("checkpoint")
            # Normalize bytes / PathLike checkpoint values to str first.
            if isinstance(model_value, (bytes, os.PathLike)):
                model_value = str(model_value)
            if not isinstance(model_value, str) or not model_value:
                # No usable checkpoint path — keep the slot informative.
                model = "model_unavailable"
            else:
                model = os.path.splitext(os.path.basename(model_value))[0]
            if len(parts) >= 2:
                model = model[: int(parts[1])]
            filename = filename.replace(segment, model)
        elif key == "date":
            from datetime import datetime
            now = datetime.now()
            # Token -> zero-padded current date/time component.
            date_table = {
                "yyyy": f"{now.year:04d}", "yy": f"{now.year % 100:02d}",
                "MM": f"{now.month:02d}", "dd": f"{now.day:02d}",
                "hh": f"{now.hour:02d}", "mm": f"{now.minute:02d}",
                "ss": f"{now.second:02d}",
            }
            if len(parts) >= 2:
                date_format = parts[1]
            else:
                date_format = "yyyyMMddhhmmss"
            for k, v in date_table.items():
                date_format = date_format.replace(k, v)
            filename = filename.replace(segment, date_format)

    return filename
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Image saving
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_images(self, images, filename_prefix, file_format, id, prompt=None,
                extra_pnginfo=None, lossless_webp=True, quality=100,
                embed_workflow=False, add_counter_to_filename=True):
    """Save a batch of image tensors to disk with generation metadata.

    Metadata is pulled from the original LoRA Manager's metadata
    collector when that package is loaded, formatted into a
    "parameters" string, and embedded as PNG text or EXIF (jpeg/webp,
    requires piexif). Returns a list of
    ``{"filename", "subfolder", "type"}`` dicts for the ComfyUI UI.

    Raises:
        ValueError: if *file_format* is not "png", "jpeg" or "webp".
    """
    # Fail fast on an unknown format. Previously an unrecognized value
    # left `file`/`save_kwargs` unbound and crashed with a NameError
    # inside the save loop.
    if file_format not in ("png", "jpeg", "webp"):
        raise ValueError(f"Unsupported file format: {file_format!r}")

    results = []

    # Try to get metadata from the original LoRA Manager's collector.
    # The package directory name varies across installs, so we search
    # sys.modules for any loaded module whose path ends with the
    # expected submodule.
    metadata_dict = {}
    try:
        get_metadata = None
        MetadataProcessor = None

        import sys
        for mod_name, mod in sys.modules.items():
            if mod is None:
                continue
            if mod_name.endswith(".py.metadata_collector") and hasattr(mod, "get_metadata"):
                get_metadata = mod.get_metadata
            if mod_name.endswith(".py.metadata_collector.metadata_processor") and hasattr(mod, "MetadataProcessor"):
                MetadataProcessor = mod.MetadataProcessor
            if get_metadata and MetadataProcessor:
                break

        if get_metadata and MetadataProcessor:
            raw_metadata = get_metadata()
            metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
        else:
            logger.debug("[LM-Remote] metadata_collector not found in loaded modules")
    except Exception:
        # Best-effort: images are still saved without generation metadata.
        logger.debug("[LM-Remote] metadata_collector not available, saving without generation metadata")

    metadata = self.format_metadata(metadata_dict)
    filename_prefix = self.format_filename(filename_prefix, metadata_dict)

    full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
        filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
    )

    if not os.path.exists(full_output_folder):
        os.makedirs(full_output_folder, exist_ok=True)

    for i, image in enumerate(images):
        # Tensor values are assumed in [0, 1] — scale to 8-bit and clamp.
        img = 255.0 * image.cpu().numpy()
        img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))

        base_filename = filename
        if add_counter_to_filename:
            current_counter = counter + i
            base_filename += f"_{current_counter:05}_"

        # Per-format file extension and Pillow save() kwargs.
        if file_format == "png":
            file = base_filename + ".png"
            save_kwargs = {"compress_level": self.compress_level}
            pnginfo = PngImagePlugin.PngInfo()
        elif file_format == "jpeg":
            file = base_filename + ".jpg"
            save_kwargs = {"quality": quality, "optimize": True}
        else:  # "webp" — guaranteed by the validation above
            file = base_filename + ".webp"
            save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}

        file_path = os.path.join(full_output_folder, file)

        try:
            if file_format == "png":
                if metadata:
                    pnginfo.add_text("parameters", metadata)
                if embed_workflow and extra_pnginfo is not None:
                    pnginfo.add_text("workflow", json.dumps(extra_pnginfo["workflow"]))
                save_kwargs["pnginfo"] = pnginfo
                img.save(file_path, format="PNG", **save_kwargs)
            elif file_format == "jpeg" and piexif:
                if metadata:
                    try:
                        # A1111-compatible UserComment encoding (UTF-16 BE).
                        exif_dict = {"Exif": {piexif.ExifIFD.UserComment: b"UNICODE\0" + metadata.encode("utf-16be")}}
                        save_kwargs["exif"] = piexif.dump(exif_dict)
                    except Exception as e:
                        logger.error("Error adding EXIF data: %s", e)
                img.save(file_path, format="JPEG", **save_kwargs)
            elif file_format == "webp" and piexif:
                try:
                    exif_dict = {}
                    if metadata:
                        exif_dict["Exif"] = {piexif.ExifIFD.UserComment: b"UNICODE\0" + metadata.encode("utf-16be")}
                    if embed_workflow and extra_pnginfo is not None:
                        exif_dict["0th"] = {piexif.ImageIFD.ImageDescription: "Workflow:" + json.dumps(extra_pnginfo["workflow"])}
                    save_kwargs["exif"] = piexif.dump(exif_dict)
                except Exception as e:
                    logger.error("Error adding EXIF data: %s", e)
                img.save(file_path, format="WEBP", **save_kwargs)
            else:
                # jpeg/webp without piexif installed: save without metadata.
                # NOTE(review): this path also drops quality/lossless kwargs
                # — presumably acceptable as a degraded fallback; confirm.
                img.save(file_path)

            results.append({"filename": file, "subfolder": subfolder, "type": self.type})
        except Exception as e:
            # One failed frame should not abort the rest of the batch.
            logger.error("Error saving image: %s", e)

    return results
|
||||
|
||||
def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png",
                  prompt=None, extra_pnginfo=None, lossless_webp=True, quality=100,
                  embed_workflow=False, add_counter_to_filename=True):
    """Node entry point: save the incoming batch and pass it through unchanged."""
    os.makedirs(self.output_dir, exist_ok=True)

    # Normalize tensor input to a list of per-image tensors: a rank-3
    # tensor is a single image, anything else is iterated as a batch.
    # Lists and numpy arrays are forwarded as-is.
    if not isinstance(images, (list, np.ndarray)):
        images = [images] if len(images.shape) == 3 else list(images)

    self.save_images(
        images, filename_prefix, file_format, id, prompt, extra_pnginfo,
        lossless_webp, quality, embed_workflow, add_counter_to_filename,
    )
    return (images,)
|
||||
141
nodes/utils.py
Normal file
141
nodes/utils.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Minimal utility classes/functions copied from the original LoRA Manager.
|
||||
|
||||
Only the pieces needed by the remote node classes are included here so that
|
||||
ComfyUI-LM-Remote can function independently of the original package's Python
|
||||
internals (while still requiring its JS widget files).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import folder_paths # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnyType(str):
    """String subclass whose ``!=`` comparison is always False.

    ComfyUI decides socket compatibility via not-equal checks, so an
    instance of this type connects to any socket. Credit to
    pythongosssss.
    """

    def __ne__(self, _other: object) -> bool:
        # Never "not equal" — the wildcard matches everything.
        return False
|
||||
|
||||
|
||||
class FlexibleOptionalInputType(dict):
    """A dict stand-in that claims to contain every key.

    ComfyUI probes the "optional" inputs mapping with ``in`` and ``[]``;
    answering positively for any key lets a node accept dynamically
    named inputs, all reported with the same type. Credit to Regis
    Gaughan, III (rgthree).
    """

    def __init__(self, type):
        # Parameter intentionally named ``type`` to mirror the upstream API.
        self.type = type

    def __contains__(self, _key):
        # Pretend every key exists.
        return True

    def __getitem__(self, _key):
        # Every lookup yields the same single-element type tuple.
        return (self.type,)
|
||||
|
||||
|
||||
any_type = AnyType("*")
|
||||
|
||||
|
||||
def extract_lora_name(lora_path: str) -> str:
    """``'IL\\\\aorunIllstrious.safetensors'`` -> ``'aorunIllstrious'``

    NOTE(review): ``os.path.basename`` splits only on the host OS's
    separator, so the backslash example above holds on Windows; on POSIX
    a backslash-separated relative path is not stripped — confirm which
    path style the remote manager returns.
    """
    stem, _ext = os.path.splitext(os.path.basename(lora_path))
    return stem
|
||||
|
||||
|
||||
def get_loras_list(kwargs: dict) -> list:
    """Extract the loras list from either the old or the new kwargs format.

    The new widget format wraps the list as ``{"__value__": [...]}``;
    the old format passes the list directly. Anything else is logged
    and treated as empty.
    """
    if "loras" not in kwargs:
        return []

    payload = kwargs["loras"]
    if isinstance(payload, dict) and "__value__" in payload:
        return payload["__value__"]
    if isinstance(payload, list):
        return payload

    logger.warning("Unexpected loras format: %s", type(payload))
    return []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Nunchaku LoRA helpers (copied verbatim from original)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
    """Load a .safetensors state dict, optionally filtered by key prefix.

    When *filter_prefix* is non-empty, only keys starting with it are
    loaded, and the prefix is stripped from the returned keys.
    """
    import safetensors.torch

    with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
        return {
            key.removeprefix(filter_prefix): f.get_tensor(key)
            for key in f.keys()
            if not filter_prefix or key.startswith(filter_prefix)
        }
|
||||
|
||||
|
||||
def to_diffusers(input_lora):
    """Convert a LoRA (a .safetensors path or a state dict) to PEFT layout.

    Non-float tensors are upcast to bfloat16 before conversion; the
    diffusers Flux loader then normalizes the state dict, which is
    finally mapped to PEFT naming.
    """
    import torch
    from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
    from diffusers.loaders import FluxLoraLoaderMixin

    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
    else:
        tensors = dict(input_lora)

    float_dtypes = (torch.float64, torch.float32, torch.bfloat16, torch.float16)
    tensors = {
        key: value if value.dtype in float_dtypes else value.to(torch.bfloat16)
        for key, value in tensors.items()
    }

    converted = FluxLoraLoaderMixin.lora_state_dict(tensors)
    return convert_unet_state_dict_to_peft(converted)
|
||||
|
||||
|
||||
def nunchaku_load_lora(model, lora_name, lora_strength):
    """Attach a LoRA to a Nunchaku-wrapped model and return the new model.

    Accepts either an absolute file path or a name resolvable through
    ComfyUI's "loras" folder. Returns the original *model* unchanged if
    the file cannot be found.
    """
    lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
    if not lora_path or not os.path.isfile(lora_path):
        logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
        return model

    model_wrapper = model.model.diffusion_model

    # Newer ComfyUI-nunchaku exposes copy_with_ctx in the wrapper's module;
    # probe for it dynamically so both old and new versions work.
    module_name = model_wrapper.__class__.__module__
    module = sys.modules.get(module_name)
    copy_with_ctx = getattr(module, "copy_with_ctx", None)

    if copy_with_ctx is not None:
        # Modern path: clone via the wrapper's own helper and extend its
        # LoRA list without mutating the source wrapper.
        ret_model_wrapper, ret_model = copy_with_ctx(model_wrapper)
        ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
    else:
        logger.warning(
            "Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. "
            "Falling back to legacy loading logic."
        )
        # Legacy path: temporarily detach the transformer so deepcopy does
        # not duplicate it, then share the same transformer between the
        # original and the copy.
        transformer = model_wrapper.model
        model_wrapper.model = None
        ret_model = copy.deepcopy(model)
        ret_model_wrapper = ret_model.model.diffusion_model
        model_wrapper.model = transformer
        ret_model_wrapper.model = transformer
        ret_model_wrapper.loras.append((lora_path, lora_strength))

    sd = to_diffusers(lora_path)

    # Some LoRAs widen the input embedder (e.g. inpainting variants);
    # grow the model config's in_channels to match, never shrink it.
    if "transformer.x_embedder.lora_A.weight" in sd:
        new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
        assert new_in_channels % 4 == 0
        new_in_channels = new_in_channels // 4
        old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
        if old_in_channels < new_in_channels:
            ret_model.model.model_config.unet_config["in_channels"] = new_in_channels

    return ret_model
|
||||
202
nodes/wanvideo.py
Normal file
202
nodes/wanvideo.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Remote WanVideo LoRA nodes — fetch metadata from the remote LoRA Manager."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import folder_paths # type: ignore
|
||||
|
||||
from .remote_utils import get_lora_info_remote
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WanVideoLoraSelectRemoteLM:
    """WanVideo LoRA selector backed by the remote LoRA Manager.

    Resolves each selected LoRA's local path and trigger words via HTTP
    to the remote instance, then emits WanVideo-format lora dicts that
    downstream WanVideo loader nodes consume.
    """

    NAME = "WanVideo Lora Select (Remote, LoraManager)"
    CATEGORY = "Lora Manager/stackers"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "low_mem_load": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Load LORA models with less VRAM usage, slower loading.",
                }),
                "merge_loras": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Merge LoRAs into the model.",
                }),
                "text": ("AUTOCOMPLETE_TEXT_LORAS", {
                    "placeholder": "Search LoRAs to add...",
                    "tooltip": "Format: <lora:lora_name:strength>",
                }),
            },
            # Flexible mapping so the JS widget can inject dynamic inputs
            # (loras payload, prev_lora chain, blocks config, ...).
            "optional": FlexibleOptionalInputType(any_type),
        }

    RETURN_TYPES = ("WANVIDLORA", "STRING", "STRING")
    RETURN_NAMES = ("lora", "trigger_words", "active_loras")
    FUNCTION = "process_loras"

    def process_loras(self, text, low_mem_load=False, merge_loras=True, **kwargs):
        """Build (lora_list, trigger_words, active_loras) from widget state."""
        loras_list = []
        all_trigger_words = []
        active_loras = []

        # Chain onto any upstream WANVIDLORA list.
        prev_lora = kwargs.get("prev_lora", None)
        if prev_lora is not None:
            loras_list.extend(prev_lora)

        # low_mem_load is only meaningful when merging.
        if not merge_loras:
            low_mem_load = False

        blocks = kwargs.get("blocks", {})
        selected_blocks = blocks.get("selected_blocks", {})
        layer_filter = blocks.get("layer_filter", "")

        loras_from_widget = get_loras_list(kwargs)
        for lora in loras_from_widget:
            # Skip entries toggled off in the widget.
            if not lora.get("active", False):
                continue

            lora_name = lora["name"]
            model_strength = float(lora["strength"])
            # clipStrength defaults to the model strength when absent.
            clip_strength = float(lora.get("clipStrength", model_strength))

            # Remote lookup: relative file path + trigger words.
            lora_path, trigger_words = get_lora_info_remote(lora_name)

            lora_item = {
                "path": folder_paths.get_full_path("loras", lora_path),
                "strength": model_strength,
                "name": lora_path.split(".")[0],
                "blocks": selected_blocks,
                "layer_filter": layer_filter,
                "low_mem_load": low_mem_load,
                "merge_loras": merge_loras,
            }

            loras_list.append(lora_item)
            active_loras.append((lora_name, model_strength, clip_strength))
            all_trigger_words.extend(trigger_words)

        # ",, " separator mirrors the original LoRA Manager's trigger-word
        # output — presumably intentional; verify against upstream.
        trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

        # Re-serialize the active selection back into <lora:...> syntax;
        # the clip strength is included only when it differs.
        formatted_loras = []
        for name, ms, cs in active_loras:
            if abs(ms - cs) > 0.001:
                formatted_loras.append(f"<lora:{name}:{str(ms).strip()}:{str(cs).strip()}>")
            else:
                formatted_loras.append(f"<lora:{name}:{str(ms).strip()}>")

        active_loras_text = " ".join(formatted_loras)
        return (loras_list, trigger_words_text, active_loras_text)
|
||||
|
||||
|
||||
class WanVideoLoraTextSelectRemoteLM:
    """WanVideo LoRA selector driven by <lora:name:strength> text syntax.

    Parses the connected syntax string, resolves each LoRA's local path
    and trigger words through the remote LoRA Manager, and appends
    WanVideo-format lora dicts to any incoming WANVIDLORA chain.
    """

    NAME = "WanVideo Lora Select From Text (Remote, LoraManager)"
    CATEGORY = "Lora Manager/stackers"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "low_mem_load": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Load LORA models with less VRAM usage, slower loading.",
                }),
                "merge_lora": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Merge LoRAs into the model.",
                }),
                "lora_syntax": ("STRING", {
                    "multiline": True,
                    "forceInput": True,
                    "tooltip": "Connect a TEXT output for LoRA syntax: <lora:name:strength>",
                }),
            },
            "optional": {
                "prev_lora": ("WANVIDLORA",),
                "blocks": ("BLOCKS",),
            },
        }

    RETURN_TYPES = ("WANVIDLORA", "STRING", "STRING")
    RETURN_NAMES = ("lora", "trigger_words", "active_loras")
    FUNCTION = "process_loras_from_syntax"

    def process_loras_from_syntax(self, lora_syntax, low_mem_load=False, merge_lora=True, **kwargs):
        """Parse *lora_syntax* into (lora_list, trigger_words, active_loras)."""
        blocks = kwargs.get("blocks", {})
        selected_blocks = blocks.get("selected_blocks", {})
        layer_filter = blocks.get("layer_filter", "")

        loras_list = []
        all_trigger_words = []
        active_loras = []

        # Chain onto any upstream WANVIDLORA list.
        prev_lora = kwargs.get("prev_lora", None)
        if prev_lora is not None:
            loras_list.extend(prev_lora)

        # low_mem_load is only meaningful when merging.
        if not merge_lora:
            low_mem_load = False

        # Lightweight <lora:...> parsing: split on the opening tag and
        # read up to the closing ">" of each fragment.
        parts = lora_syntax.split("<lora:")
        for part in parts[1:]:
            end_index = part.find(">")
            if end_index == -1:
                # Unterminated tag — ignore the fragment.
                continue

            content = part[:end_index]
            lora_parts = content.split(":")

            lora_name_raw = ""
            model_strength = 1.0
            clip_strength = 1.0

            # "name:strength" — clip strength follows the model strength.
            if len(lora_parts) == 2:
                lora_name_raw = lora_parts[0].strip()
                try:
                    model_strength = float(lora_parts[1])
                    clip_strength = model_strength
                except (ValueError, IndexError):
                    logger.warning("Invalid strength for LoRA '%s'. Skipping.", lora_name_raw)
                    continue
            # "name:model_strength:clip_strength"
            elif len(lora_parts) >= 3:
                lora_name_raw = lora_parts[0].strip()
                try:
                    model_strength = float(lora_parts[1])
                    clip_strength = float(lora_parts[2])
                except (ValueError, IndexError):
                    logger.warning("Invalid strengths for LoRA '%s'. Skipping.", lora_name_raw)
                    continue
            else:
                # Bare "<lora:name>" without a strength — skip.
                continue

            # Remote lookup: relative file path + trigger words.
            lora_path, trigger_words = get_lora_info_remote(lora_name_raw)

            lora_item = {
                "path": folder_paths.get_full_path("loras", lora_path),
                "strength": model_strength,
                "name": lora_path.split(".")[0],
                "blocks": selected_blocks,
                "layer_filter": layer_filter,
                "low_mem_load": low_mem_load,
                "merge_loras": merge_lora,
            }

            loras_list.append(lora_item)
            active_loras.append((lora_name_raw, model_strength, clip_strength))
            all_trigger_words.extend(trigger_words)

        # ",, " separator mirrors the original LoRA Manager's trigger-word
        # output — presumably intentional; verify against upstream.
        trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

        # Re-serialize the selection back into <lora:...> syntax; the
        # clip strength is included only when it differs.
        formatted_loras = []
        for name, ms, cs in active_loras:
            if abs(ms - cs) > 0.001:
                formatted_loras.append(f"<lora:{name}:{str(ms).strip()}:{str(cs).strip()}>")
            else:
                formatted_loras.append(f"<lora:{name}:{str(ms).strip()}>")

        active_loras_text = " ".join(formatted_loras)
        return (loras_list, trigger_words_text, active_loras_text)
|
||||
243
proxy.py
Normal file
243
proxy.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
Reverse-proxy middleware that forwards LoRA Manager requests to the remote instance.
|
||||
|
||||
Registered as an aiohttp middleware on PromptServer.instance.app. It intercepts
|
||||
requests matching known LoRA Manager URL prefixes and proxies them to the remote
|
||||
Docker instance. Non-matching requests fall through to the regular ComfyUI router.
|
||||
|
||||
Routes that use ``PromptServer.instance.send_sync()`` are explicitly excluded
|
||||
from proxying so the local original LoRA Manager handler can broadcast events
|
||||
to the local ComfyUI frontend.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web, WSMsgType
|
||||
|
||||
from .config import remote_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL prefixes that should be forwarded to the remote LoRA Manager
|
||||
# ---------------------------------------------------------------------------
|
||||
_PROXY_PREFIXES = (
|
||||
"/api/lm/",
|
||||
"/loras_static/",
|
||||
"/locales/",
|
||||
"/example_images_static/",
|
||||
)
|
||||
|
||||
# Page routes served by the standalone LoRA Manager web UI
|
||||
_PROXY_PAGE_ROUTES = {
|
||||
"/loras",
|
||||
"/checkpoints",
|
||||
"/embeddings",
|
||||
"/loras/recipes",
|
||||
"/statistics",
|
||||
}
|
||||
|
||||
# WebSocket endpoints to proxy
|
||||
_WS_ROUTES = {
|
||||
"/ws/fetch-progress",
|
||||
"/ws/download-progress",
|
||||
"/ws/init-progress",
|
||||
}
|
||||
|
||||
# Routes that call send_sync on the remote side — these are NOT proxied.
|
||||
# Instead they fall through to the local original LoRA Manager handler,
|
||||
# which broadcasts events to the local ComfyUI frontend. The remote
|
||||
# handler would broadcast to its own (empty) frontend, which is useless.
|
||||
#
|
||||
# These routes:
|
||||
# /api/lm/loras/get_trigger_words -> trigger_word_update event
|
||||
# /api/lm/update-lora-code -> lora_code_update event
|
||||
# /api/lm/update-node-widget -> lm_widget_update event
|
||||
# /api/lm/register-nodes -> lora_registry_refresh event
|
||||
_SEND_SYNC_SKIP_ROUTES = {
|
||||
"/api/lm/loras/get_trigger_words",
|
||||
"/api/lm/update-lora-code",
|
||||
"/api/lm/update-node-widget",
|
||||
"/api/lm/register-nodes",
|
||||
}
|
||||
|
||||
# Shared HTTP session for proxied requests (connection pooling)
|
||||
_proxy_session: aiohttp.ClientSession | None = None
|
||||
|
||||
|
||||
async def _get_proxy_session() -> aiohttp.ClientSession:
    """Return the shared aiohttp session, (re)creating it when missing or closed."""
    global _proxy_session
    needs_new = _proxy_session is None or _proxy_session.closed
    if needs_new:
        _proxy_session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=remote_config.timeout)
        )
    return _proxy_session
|
||||
|
||||
|
||||
def _should_proxy(path: str) -> bool:
    """Return True if *path* should be proxied to the remote instance."""
    # Prefix-based API/static routes (startswith accepts the whole tuple).
    if path.startswith(_PROXY_PREFIXES):
        return True
    # Exact page routes, tolerating a trailing slash.
    return path in _PROXY_PAGE_ROUTES or path.rstrip("/") in _PROXY_PAGE_ROUTES
|
||||
|
||||
|
||||
def _is_ws_route(path: str) -> bool:
    """True when *path* is one of the proxied WebSocket endpoints."""
    return path in _WS_ROUTES
|
||||
|
||||
|
||||
async def _proxy_ws(request: web.Request) -> web.WebSocketResponse:
    """Proxy a WebSocket connection to the remote LoRA Manager.

    Accepts the client's WebSocket locally, opens a second WebSocket to
    the remote instance, and pumps text/binary frames in both directions
    until either side closes.
    """
    # Rewrite the configured HTTP(S) base URL into its WS(S) equivalent.
    remote_url = remote_config.remote_url.replace("http://", "ws://").replace("https://", "wss://")
    remote_ws_url = f"{remote_url}{request.path}"
    if request.query_string:
        remote_ws_url += f"?{request.query_string}"

    local_ws = web.WebSocketResponse()
    await local_ws.prepare(request)

    # No total timeout: progress sockets may stay open indefinitely.
    timeout = aiohttp.ClientTimeout(total=None)
    session = aiohttp.ClientSession(timeout=timeout)
    try:
        async with session.ws_connect(remote_ws_url) as remote_ws:

            async def forward_local_to_remote():
                # Pump client -> remote until the client side closes.
                async for msg in local_ws:
                    if msg.type == WSMsgType.TEXT:
                        await remote_ws.send_str(msg.data)
                    elif msg.type == WSMsgType.BINARY:
                        await remote_ws.send_bytes(msg.data)
                    elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
                        return

            async def forward_remote_to_local():
                # Pump remote -> client until the remote side closes.
                async for msg in remote_ws:
                    if msg.type == WSMsgType.TEXT:
                        await local_ws.send_str(msg.data)
                    elif msg.type == WSMsgType.BINARY:
                        await local_ws.send_bytes(msg.data)
                    elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
                        return

            # Run both directions concurrently. When either side closes,
            # cancel the other to prevent hanging.
            task_l2r = asyncio.create_task(forward_local_to_remote())
            task_r2l = asyncio.create_task(forward_remote_to_local())
            try:
                done, pending = await asyncio.wait(
                    {task_l2r, task_r2l}, return_when=asyncio.FIRST_COMPLETED
                )
                for task in pending:
                    task.cancel()
                    try:
                        # Await the cancellation so the task is fully torn down.
                        await task
                    except asyncio.CancelledError:
                        pass
            finally:
                # Ensure both sides are closed
                if not remote_ws.closed:
                    await remote_ws.close()
                if not local_ws.closed:
                    await local_ws.close()

    except Exception as exc:
        # Connection/setup failure: log and return the (closed) local socket.
        logger.warning("[LM-Remote] WebSocket proxy error for %s: %s", request.path, exc)
    finally:
        await session.close()

    return local_ws
|
||||
|
||||
|
||||
async def _proxy_http(request: web.Request) -> web.Response:
    """Forward an HTTP request to the remote LoRA Manager and return its response.

    On any transport failure the client receives a 502 JSON error body
    instead of an exception propagating into the server.
    """
    remote_url = f"{remote_config.remote_url}{request.path}"
    if request.query_string:
        remote_url += f"?{request.query_string}"

    # Read the request body (if any)
    body = await request.read() if request.can_read_body else None

    # Filter hop-by-hop headers
    headers = {}
    skip = {"host", "transfer-encoding", "connection", "keep-alive", "upgrade"}
    for k, v in request.headers.items():
        if k.lower() not in skip:
            headers[k] = v

    session = await _get_proxy_session()
    try:
        async with session.request(
            method=request.method,
            url=remote_url,
            headers=headers,
            data=body,
        ) as resp:
            resp_body = await resp.read()
            resp_headers = {}
            # resp.read() yields the already-decompressed payload, so the
            # remote's Content-Encoding / Content-Length / Transfer-Encoding
            # no longer describe the body and must be dropped (the length
            # is recomputed for the new Response).
            for k, v in resp.headers.items():
                if k.lower() not in ("transfer-encoding", "content-encoding", "content-length"):
                    resp_headers[k] = v
            return web.Response(
                status=resp.status,
                body=resp_body,
                headers=resp_headers,
            )
    except Exception as exc:
        logger.error("[LM-Remote] Proxy error for %s %s: %s", request.method, request.path, exc)
        return web.json_response(
            {"error": f"Remote LoRA Manager unavailable: {exc}"},
            status=502,
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@web.middleware
async def lm_remote_proxy_middleware(request: web.Request, handler):
    """aiohttp middleware that intercepts LoRA Manager requests.

    Dispatch order matters: the send_sync exclusions are checked before
    the generic proxy match because they live under the proxied
    ``/api/lm/`` prefix; WebSocket routes are checked before plain HTTP
    proxying; everything else falls through to the normal ComfyUI
    handler chain.
    """
    if not remote_config.is_configured:
        return await handler(request)

    path = request.path

    # Routes that use send_sync must NOT be proxied — let the local
    # original LoRA Manager handle them so events reach the local browser.
    if path in _SEND_SYNC_SKIP_ROUTES:
        return await handler(request)

    # WebSocket routes
    if _is_ws_route(path):
        return await _proxy_ws(request)

    # Regular proxy routes
    if _should_proxy(path):
        return await _proxy_http(request)

    # Not a LoRA Manager route — fall through
    return await handler(request)
|
||||
|
||||
|
||||
async def _cleanup_proxy_session(app) -> None:
    """aiohttp shutdown hook: close the pooled proxy session if it is open."""
    global _proxy_session
    session = _proxy_session
    if session is not None and not session.closed:
        await session.close()
        _proxy_session = None
|
||||
|
||||
|
||||
def register_proxy(app) -> None:
    """Insert the proxy middleware into the aiohttp app."""
    if remote_config.is_configured:
        # Insert at position 0 so we run before the original LoRA Manager
        # routes, and register the session cleanup for server shutdown.
        app.middlewares.insert(0, lm_remote_proxy_middleware)
        app.on_shutdown.append(_cleanup_proxy_session)
        logger.info("[LM-Remote] Proxy routes registered -> %s", remote_config.remote_url)
    else:
        logger.warning("[LM-Remote] No remote_url configured — proxy disabled")
|
||||
214
remote_client.py
Normal file
214
remote_client.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""HTTP client for the remote LoRA Manager instance."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .config import remote_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache TTL in seconds — how long before we re-fetch the full LoRA list
|
||||
_CACHE_TTL = 60
|
||||
|
||||
|
||||
class RemoteLoraClient:
    """Singleton HTTP client that talks to the remote LoRA Manager.

    Uses the actual LoRA Manager REST API endpoints:

    - ``GET /api/lm/loras/list?page_size=9999`` — paginated LoRA list
    - ``GET /api/lm/loras/get-trigger-words?name=X`` — trigger words
    - ``POST /api/lm/loras/random-sample`` — random LoRA selection
    - ``POST /api/lm/loras/cycler-list`` — sorted LoRA list for cycler

    A short-lived in-memory cache avoids redundant calls to the list endpoint
    during a single workflow execution (which may resolve many LoRAs at once).
    """

    _instance: RemoteLoraClient | None = None
    # Lazily created shared aiohttp session; recreated if it was closed.
    _session: aiohttp.ClientSession | None = None

    def __init__(self):
        # (items, monotonic timestamp) cache for each of the two list endpoints.
        self._lora_cache: list[dict] = []
        self._lora_cache_ts: float = 0
        self._checkpoint_cache: list[dict] = []
        self._checkpoint_cache_ts: float = 0

    @classmethod
    def get_instance(cls) -> RemoteLoraClient:
        """Return the process-wide singleton, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared ClientSession, (re)creating it if missing/closed."""
        if self._session is None or self._session.closed:
            timeout = aiohttp.ClientTimeout(total=remote_config.timeout)
            self._session = aiohttp.ClientSession(timeout=timeout)
        return self._session

    async def close(self):
        """Close the shared session (called e.g. on server shutdown)."""
        if self._session and not self._session.closed:
            await self._session.close()
            self._session = None

    # ------------------------------------------------------------------
    # Core HTTP helpers
    # ------------------------------------------------------------------

    async def _get_json(self, path: str, params: dict | None = None) -> Any:
        """GET ``path`` on the remote and return the decoded JSON body.

        Raises ``aiohttp.ClientResponseError`` on a non-2xx status.
        """
        url = f"{remote_config.remote_url}{path}"
        session = await self._get_session()
        async with session.get(url, params=params) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def _post_json(self, path: str, json_body: dict | None = None) -> Any:
        """POST ``json_body`` to ``path`` on the remote and return decoded JSON.

        Raises ``aiohttp.ClientResponseError`` on a non-2xx status.
        """
        url = f"{remote_config.remote_url}{path}"
        session = await self._get_session()
        async with session.post(url, json=json_body) as resp:
            resp.raise_for_status()
            return await resp.json()

    # ------------------------------------------------------------------
    # Cached list helpers
    # ------------------------------------------------------------------

    async def _fetch_list_cached(
        self,
        endpoint: str,
        cache: list[dict],
        cache_ts: float,
        label: str,
    ) -> tuple[list[dict], float]:
        """Shared TTL-cache logic for the two ``.../list`` endpoints.

        Returns ``(items, timestamp)``: fresh data when the TTL has expired
        and the fetch succeeded, otherwise the passed-in (possibly stale or
        empty) ``cache`` unchanged.  ``label`` is only used for logging.
        """
        now = time.monotonic()
        if cache and (now - cache_ts) < _CACHE_TTL:
            return cache, cache_ts
        try:
            data = await self._get_json(endpoint, params={"page_size": "9999"})
            return data.get("items", []), now
        except Exception as exc:
            # Return stale cache on error, or empty list
            logger.warning("[LM-Remote] Failed to fetch %s list: %s", label, exc)
            return cache, cache_ts

    async def _get_lora_list_cached(self) -> list[dict]:
        """Return the full LoRA list, using a short-lived cache."""
        self._lora_cache, self._lora_cache_ts = await self._fetch_list_cached(
            "/api/lm/loras/list", self._lora_cache, self._lora_cache_ts, "LoRA"
        )
        return self._lora_cache

    async def _get_checkpoint_list_cached(self) -> list[dict]:
        """Return the full checkpoint list, using a short-lived cache."""
        self._checkpoint_cache, self._checkpoint_cache_ts = await self._fetch_list_cached(
            "/api/lm/checkpoints/list",
            self._checkpoint_cache,
            self._checkpoint_cache_ts,
            "checkpoint",
        )
        return self._checkpoint_cache

    def _find_item_by_name(self, items: list[dict], name: str) -> dict | None:
        """Return the first item whose ``file_name`` equals ``name``, or None."""
        for item in items:
            if item.get("file_name") == name:
                return item
        return None

    # ------------------------------------------------------------------
    # LoRA metadata
    # ------------------------------------------------------------------

    async def get_lora_info(self, lora_name: str) -> tuple[str, list[str]]:
        """Return (relative_path, trigger_words) for a LoRA by display name.

        Uses the cached ``/api/lm/loras/list`` data. Falls back to the
        per-LoRA ``get-trigger-words`` endpoint if the list lookup fails.
        On any error, degrades to ``(lora_name, [])``.
        """
        import posixpath

        try:
            items = await self._get_lora_list_cached()
            item = self._find_item_by_name(items, lora_name)

            if item:
                file_path = remote_config.map_path(item.get("file_path", ""))

                # file_path is the absolute path (forward-slashed) from the
                # remote.  The local folder_paths.get_full_path("loras", ...)
                # wants a path relative to the model root, which we rebuild
                # from the item's ``folder`` field plus the basename:
                #   file_path="/mnt/loras/anime/test.safetensors", folder="anime"
                #   -> "anime/test.safetensors"
                folder = item.get("folder", "")
                basename = posixpath.basename(file_path)
                relative = f"{folder}/{basename}" if folder else basename

                civitai = item.get("civitai") or {}
                return relative, civitai.get("trainedWords", [])

            # Fallback: the dedicated trigger-words endpoint.  It carries no
            # path info, so the display name doubles as the relative path.
            tw_data = await self._get_json(
                "/api/lm/loras/get-trigger-words",
                params={"name": lora_name},
            )
            return lora_name, tw_data.get("trigger_words", [])

        except Exception as exc:
            logger.warning("[LM-Remote] get_lora_info(%s) failed: %s", lora_name, exc)
            return lora_name, []

    async def get_lora_hash(self, lora_name: str) -> str | None:
        """Return the SHA-256 hash for a LoRA by display name, or None."""
        try:
            items = await self._get_lora_list_cached()
            item = self._find_item_by_name(items, lora_name)
            if item:
                # Field name differs between LoRA Manager versions.
                return item.get("sha256") or item.get("hash")
        except Exception as exc:
            logger.warning("[LM-Remote] get_lora_hash(%s) failed: %s", lora_name, exc)
        return None

    async def get_checkpoint_hash(self, checkpoint_name: str) -> str | None:
        """Return the SHA-256 hash for a checkpoint by display name, or None."""
        try:
            items = await self._get_checkpoint_list_cached()
            item = self._find_item_by_name(items, checkpoint_name)
            if item:
                return item.get("sha256") or item.get("hash")
        except Exception as exc:
            logger.warning("[LM-Remote] get_checkpoint_hash(%s) failed: %s", checkpoint_name, exc)
        return None

    async def get_random_loras(self, **kwargs) -> list[dict]:
        """Ask the remote to generate random LoRAs (for the Randomizer node).

        Returns [] on any error.
        """
        try:
            result = await self._post_json("/api/lm/loras/random-sample", json_body=kwargs)
            # Older/newer API versions differ: bare list vs {"loras": [...]}.
            return result if isinstance(result, list) else result.get("loras", [])
        except Exception as exc:
            logger.warning("[LM-Remote] get_random_loras failed: %s", exc)
            return []

    async def get_cycler_list(self, **kwargs) -> list[dict]:
        """Ask the remote for a sorted LoRA list (for the Cycler node).

        Returns [] on any error.
        """
        try:
            result = await self._post_json("/api/lm/loras/cycler-list", json_body=kwargs)
            return result if isinstance(result, list) else result.get("loras", [])
        except Exception as exc:
            logger.warning("[LM-Remote] get_cycler_list failed: %s", exc)
            return []
||||
166
web/comfyui/lora_loader_remote.js
Normal file
166
web/comfyui/lora_loader_remote.js
Normal file
@@ -0,0 +1,166 @@
|
||||
/**
|
||||
* JS shim for Remote Lora Loader / Remote Lora Text Loader nodes.
|
||||
*
|
||||
* Re-uses all widget infrastructure from the original ComfyUI-Lora-Manager;
|
||||
* the only difference is matching on the remote node NAMEs.
|
||||
*/
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
import {
|
||||
collectActiveLorasFromChain,
|
||||
updateConnectedTriggerWords,
|
||||
chainCallback,
|
||||
mergeLoras,
|
||||
getAllGraphNodes,
|
||||
getNodeFromGraph,
|
||||
} from "/extensions/ComfyUI-Lora-Manager/utils.js";
|
||||
import { addLorasWidget } from "/extensions/ComfyUI-Lora-Manager/loras_widget.js";
|
||||
import { applyLoraValuesToText, debounce } from "/extensions/ComfyUI-Lora-Manager/lora_syntax_utils.js";
|
||||
import { applySelectionHighlight } from "/extensions/ComfyUI-Lora-Manager/trigger_word_highlight.js";
|
||||
|
||||
app.registerExtension({
    name: "LoraManager.LoraLoaderRemote",

    setup() {
        // Backend broadcast channel: lora_code_update events carry LoRA
        // syntax to inject into loader/stacker nodes.
        api.addEventListener("lora_code_update", (event) => {
            this.handleLoraCodeUpdate(event.detail || {});
        });
    },

    /**
     * Handle a lora_code_update message from the backend.
     * node_id === -1 broadcasts to every remote loader node; otherwise the
     * update targets one specific loader/stacker/WanVideo-select node.
     */
    handleLoraCodeUpdate(message) {
        const nodeId = message?.node_id ?? message?.id;
        const graphId = message?.graph_id;
        const loraCode = message?.lora_code ?? "";
        const mode = message?.mode ?? "append";
        // IDs may arrive as strings over the event channel.
        const numericNodeId = typeof nodeId === "string" ? Number(nodeId) : nodeId;

        if (numericNodeId === -1) {
            // Broadcast: apply to every remote loader node in the graph.
            const loraLoaderNodes = getAllGraphNodes(app.graph)
                .map(({ node }) => node)
                .filter((node) => node?.comfyClass === "Lora Loader (Remote, LoraManager)");

            if (loraLoaderNodes.length > 0) {
                loraLoaderNodes.forEach((node) => {
                    this.updateNodeLoraCode(node, loraCode, mode);
                });
            }
            return;
        }

        const node = getNodeFromGraph(graphId, numericNodeId);
        // Only the three remote node classes accept lora-code injection.
        if (
            !node ||
            (node.comfyClass !== "Lora Loader (Remote, LoraManager)" &&
                node.comfyClass !== "Lora Stacker (Remote, LoraManager)" &&
                node.comfyClass !== "WanVideo Lora Select (Remote, LoraManager)")
        ) {
            return;
        }
        this.updateNodeLoraCode(node, loraCode, mode);
    },

    /**
     * Write loraCode into a node's text input widget, either replacing the
     * current value or appending with a space separator, then fire the
     * widget callback so the loras widget re-syncs.
     */
    updateNodeLoraCode(node, loraCode, mode) {
        const inputWidget = node.inputWidget;
        if (!inputWidget) return;

        const currentValue = inputWidget.value || "";
        if (mode === "replace") {
            inputWidget.value = loraCode;
        } else {
            inputWidget.value = currentValue.trim()
                ? `${currentValue.trim()} ${loraCode}`
                : loraCode;
        }

        if (typeof inputWidget.callback === "function") {
            inputWidget.callback(inputWidget.value);
        }
    },

    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeType.comfyClass === "Lora Loader (Remote, LoraManager)") {
            chainCallback(nodeType.prototype, "onNodeCreated", function () {
                this.serialize_widgets = true;

                // NOTE(review): shape 7 appears to mark these inputs as
                // optional in the frontend — confirm against LiteGraph docs.
                this.addInput("clip", "CLIP", { shape: 7 });
                this.addInput("lora_stack", "LORA_STACK", { shape: 7 });

                // Re-entrancy guards: the text widget and the loras widget
                // update each other, so suppress callback echo loops.
                let isUpdating = false;
                let isSyncingInput = false;

                const self = this;
                // Wrap `mode` in a property so toggling mute/bypass can
                // refresh connected trigger words via onModeChange.
                let _mode = this.mode;
                Object.defineProperty(this, "mode", {
                    get() { return _mode; },
                    set(value) {
                        const oldValue = _mode;
                        _mode = value;
                        if (self.onModeChange) self.onModeChange(value, oldValue);
                    },
                });

                this.onModeChange = function (newMode) {
                    const allActiveLoraNames = collectActiveLorasFromChain(self);
                    updateConnectedTriggerWords(self, allActiveLoraNames);
                };

                const inputWidget = this.widgets[0];
                this.inputWidget = inputWidget;

                // Debounced: push loras-widget state back into the text input.
                const scheduleInputSync = debounce((lorasValue) => {
                    if (isSyncingInput) return;
                    isSyncingInput = true;
                    isUpdating = true;
                    try {
                        const nextText = applyLoraValuesToText(inputWidget.value, lorasValue);
                        if (inputWidget.value !== nextText) inputWidget.value = nextText;
                    } finally {
                        isUpdating = false;
                        isSyncingInput = false;
                    }
                });

                this.lorasWidget = addLorasWidget(
                    this, "loras",
                    { onSelectionChange: (selection) => applySelectionHighlight(this, selection) },
                    (value) => {
                        if (isUpdating) return;
                        isUpdating = true;
                        try {
                            const allActiveLoraNames = collectActiveLorasFromChain(this);
                            updateConnectedTriggerWords(this, allActiveLoraNames);
                        } finally {
                            isUpdating = false;
                        }
                        scheduleInputSync(value);
                    }
                ).widget;

                // Text input edited -> merge syntax into the loras widget.
                inputWidget.callback = (value) => {
                    if (isUpdating) return;
                    isUpdating = true;
                    try {
                        const currentLoras = this.lorasWidget.value || [];
                        const mergedLoras = mergeLoras(value, currentLoras);
                        this.lorasWidget.value = mergedLoras;
                        const allActiveLoraNames = collectActiveLorasFromChain(this);
                        updateConnectedTriggerWords(this, allActiveLoraNames);
                    } finally {
                        isUpdating = false;
                    }
                };
            });
        }
    },

    async loadedGraphNode(node) {
        if (node.comfyClass === "Lora Loader (Remote, LoraManager)") {
            let existingLoras = [];
            if (node.widgets_values && node.widgets_values.length > 0) {
                // widget order: 0=text input, 1=loras widget state
                existingLoras = node.widgets_values[1] || [];
            }
            // Fix: use the cached inputWidget with a widgets[0] fallback,
            // consistent with the stacker/wanvideo shims.
            const inputWidget = node.inputWidget || node.widgets[0];
            const mergedLoras = mergeLoras(inputWidget.value, existingLoras);
            node.lorasWidget.value = mergedLoras;
        }
    },
});
||||
97
web/comfyui/lora_stacker_remote.js
Normal file
97
web/comfyui/lora_stacker_remote.js
Normal file
@@ -0,0 +1,97 @@
|
||||
/**
|
||||
* JS shim for Remote Lora Stacker node.
|
||||
*/
|
||||
import { app } from "../../scripts/app.js";
|
||||
import {
|
||||
getActiveLorasFromNode,
|
||||
updateConnectedTriggerWords,
|
||||
updateDownstreamLoaders,
|
||||
chainCallback,
|
||||
mergeLoras,
|
||||
} from "/extensions/ComfyUI-Lora-Manager/utils.js";
|
||||
import { addLorasWidget } from "/extensions/ComfyUI-Lora-Manager/loras_widget.js";
|
||||
import { applyLoraValuesToText, debounce } from "/extensions/ComfyUI-Lora-Manager/lora_syntax_utils.js";
|
||||
import { applySelectionHighlight } from "/extensions/ComfyUI-Lora-Manager/trigger_word_highlight.js";
|
||||
|
||||
app.registerExtension({
    name: "LoraManager.LoraStackerRemote",

    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeType.comfyClass === "Lora Stacker (Remote, LoraManager)") {
            chainCallback(nodeType.prototype, "onNodeCreated", async function () {
                this.serialize_widgets = true;
                // NOTE(review): shape 7 appears to mark this input as
                // optional in the frontend — confirm against LiteGraph docs.
                this.addInput("lora_stack", "LORA_STACK", { shape: 7 });

                // Re-entrancy guards: the text widget and the loras widget
                // update each other, so suppress callback echo loops.
                let isUpdating = false;
                let isSyncingInput = false;

                const inputWidget = this.widgets[0];
                this.inputWidget = inputWidget;

                // Debounced: push loras-widget state back into the text input.
                const scheduleInputSync = debounce((lorasValue) => {
                    if (isSyncingInput) return;
                    isSyncingInput = true;
                    isUpdating = true;
                    try {
                        const nextText = applyLoraValuesToText(inputWidget.value, lorasValue);
                        if (inputWidget.value !== nextText) inputWidget.value = nextText;
                    } finally {
                        isUpdating = false;
                        isSyncingInput = false;
                    }
                });

                const result = addLorasWidget(
                    this, "loras",
                    { onSelectionChange: (selection) => applySelectionHighlight(this, selection) },
                    (value) => {
                        if (isUpdating) return;
                        isUpdating = true;
                        try {
                            // presumably mode 0/3 are the "active" node modes
                            // (undefined on fresh nodes) — TODO confirm.
                            const isNodeActive = this.mode === undefined || this.mode === 0 || this.mode === 3;
                            const activeLoraNames = new Set();
                            if (isNodeActive) {
                                value.forEach((lora) => { if (lora.active) activeLoraNames.add(lora.name); });
                            }
                            updateConnectedTriggerWords(this, activeLoraNames);
                            updateDownstreamLoaders(this);
                        } finally {
                            isUpdating = false;
                        }
                        scheduleInputSync(value);
                    }
                );

                this.lorasWidget = result.widget;

                // Text input edited -> merge syntax into the loras widget
                // and refresh trigger words / downstream loaders.
                inputWidget.callback = (value) => {
                    if (isUpdating) return;
                    isUpdating = true;
                    try {
                        const currentLoras = this.lorasWidget?.value || [];
                        const mergedLoras = mergeLoras(value, currentLoras);
                        if (this.lorasWidget) this.lorasWidget.value = mergedLoras;
                        const isNodeActive = this.mode === undefined || this.mode === 0 || this.mode === 3;
                        const activeLoraNames = isNodeActive ? getActiveLorasFromNode(this) : new Set();
                        updateConnectedTriggerWords(this, activeLoraNames);
                        updateDownstreamLoaders(this);
                    } finally {
                        isUpdating = false;
                    }
                };
            });
        }
    },

    async loadedGraphNode(node) {
        // On workflow load, re-merge serialized loras state with the text input.
        if (node.comfyClass === "Lora Stacker (Remote, LoraManager)") {
            let existingLoras = [];
            if (node.widgets_values && node.widgets_values.length > 0) {
                // widget order: 0=text input, 1=loras widget state
                existingLoras = node.widgets_values[1] || [];
            }
            const inputWidget = node.inputWidget || node.widgets[0];
            const mergedLoras = mergeLoras(inputWidget.value, existingLoras);
            node.lorasWidget.value = mergedLoras;
        }
    },
});
|
||||
89
web/comfyui/wanvideo_remote.js
Normal file
89
web/comfyui/wanvideo_remote.js
Normal file
@@ -0,0 +1,89 @@
|
||||
/**
|
||||
* JS shim for Remote WanVideo Lora Select node.
|
||||
*/
|
||||
import { app } from "../../scripts/app.js";
|
||||
import {
|
||||
getActiveLorasFromNode,
|
||||
updateConnectedTriggerWords,
|
||||
chainCallback,
|
||||
mergeLoras,
|
||||
} from "/extensions/ComfyUI-Lora-Manager/utils.js";
|
||||
import { addLorasWidget } from "/extensions/ComfyUI-Lora-Manager/loras_widget.js";
|
||||
import { applyLoraValuesToText, debounce } from "/extensions/ComfyUI-Lora-Manager/lora_syntax_utils.js";
|
||||
|
||||
app.registerExtension({
    name: "LoraManager.WanVideoLoraSelectRemote",

    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeType.comfyClass === "WanVideo Lora Select (Remote, LoraManager)") {
            chainCallback(nodeType.prototype, "onNodeCreated", async function () {
                this.serialize_widgets = true;

                // NOTE(review): shape 7 appears to mark these inputs as
                // optional in the frontend — confirm against LiteGraph docs.
                this.addInput("prev_lora", "WANVIDLORA", { shape: 7 });
                this.addInput("blocks", "SELECTEDBLOCKS", { shape: 7 });

                // Re-entrancy guards: the text widget and the loras widget
                // update each other, so suppress callback echo loops.
                let isUpdating = false;
                let isSyncingInput = false;

                // text widget is at index 2 (after low_mem_load, merge_loras)
                const inputWidget = this.widgets[2];
                this.inputWidget = inputWidget;

                // Debounced: push loras-widget state back into the text input.
                const scheduleInputSync = debounce((lorasValue) => {
                    if (isSyncingInput) return;
                    isSyncingInput = true;
                    isUpdating = true;
                    try {
                        const nextText = applyLoraValuesToText(inputWidget.value, lorasValue);
                        if (inputWidget.value !== nextText) inputWidget.value = nextText;
                    } finally {
                        isUpdating = false;
                        isSyncingInput = false;
                    }
                });

                // Loras widget edited -> refresh connected trigger words,
                // then sync the text input.
                const result = addLorasWidget(this, "loras", {}, (value) => {
                    if (isUpdating) return;
                    isUpdating = true;
                    try {
                        const activeLoraNames = new Set();
                        value.forEach((lora) => { if (lora.active) activeLoraNames.add(lora.name); });
                        updateConnectedTriggerWords(this, activeLoraNames);
                    } finally {
                        isUpdating = false;
                    }
                    scheduleInputSync(value);
                });

                this.lorasWidget = result.widget;

                // Text input edited -> merge syntax into the loras widget.
                inputWidget.callback = (value) => {
                    if (isUpdating) return;
                    isUpdating = true;
                    try {
                        const currentLoras = this.lorasWidget?.value || [];
                        const mergedLoras = mergeLoras(value, currentLoras);
                        if (this.lorasWidget) this.lorasWidget.value = mergedLoras;
                        const activeLoraNames = getActiveLorasFromNode(this);
                        updateConnectedTriggerWords(this, activeLoraNames);
                    } finally {
                        isUpdating = false;
                    }
                };
            });
        }
    },

    async loadedGraphNode(node) {
        // On workflow load, re-merge serialized loras state with the text input.
        if (node.comfyClass === "WanVideo Lora Select (Remote, LoraManager)") {
            let existingLoras = [];
            if (node.widgets_values && node.widgets_values.length > 0) {
                // 0=low_mem_load, 1=merge_loras, 2=text, 3=loras
                existingLoras = node.widgets_values[3] || [];
            }
            const inputWidget = node.inputWidget || node.widgets[2];
            const mergedLoras = mergeLoras(inputWidget.value, existingLoras);
            node.lorasWidget.value = mergedLoras;
        }
    },
});
|
||||
Reference in New Issue
Block a user