Remote-aware LoRA Manager nodes that fetch metadata via HTTP from a remote Docker instance while loading LoRA files from local NFS/SMB mounts. Includes reverse-proxy middleware for transparent web UI access. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
382 lines
16 KiB
Python
382 lines
16 KiB
Python
"""Remote Save Image — uses remote API for hash lookups."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import concurrent.futures
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
|
|
import folder_paths # type: ignore
|
|
import numpy as np
|
|
from PIL import Image, PngImagePlugin
|
|
|
|
from ..remote_client import RemoteLoraClient
|
|
|
|
logger = logging.getLogger(__name__)

# piexif is an optional dependency. When it cannot be imported the name is
# bound to None, and JPEG/WebP output is written without EXIF metadata.
piexif = None
try:
    import piexif
except ImportError:
    pass
|
|
|
|
|
|
class SaveImageRemoteLM:
    """ComfyUI output node that saves images with embedded generation metadata.

    Generation parameters are read from LoRA Manager's ``metadata_collector``
    when that package is loaded in this process. LoRA and checkpoint hashes
    are resolved through ``RemoteLoraClient`` (the remote API) rather than a
    local cache.
    """

    NAME = "Save Image (Remote, LoraManager)"
    CATEGORY = "Lora Manager/utils"
    DESCRIPTION = "Save images with embedded generation metadata (remote hash lookup)"

    # %placeholder% segments in filename_prefix, e.g. %seed% or %pprompt:40%.
    pattern_format = re.compile(r"(%[^%]+%)")

    # <lora:NAME:STRENGTH> tags in the "loras" metadata text.
    _LORA_TAG = re.compile(r"<lora:([^:]+):([^>]+)>")

    # ComfyUI sampler / scheduler identifiers -> names written after "Sampler:".
    # Identifiers not listed here are written unchanged.
    _SAMPLER_NAMES = {
        "euler": "Euler", "euler_ancestral": "Euler a",
        "dpm_2": "DPM2", "dpm_2_ancestral": "DPM2 a",
        "heun": "Heun", "dpm_fast": "DPM fast",
        "dpm_adaptive": "DPM adaptive", "lms": "LMS",
        "dpmpp_2s_ancestral": "DPM++ 2S a", "dpmpp_sde": "DPM++ SDE",
        "dpmpp_sde_gpu": "DPM++ SDE", "dpmpp_2m": "DPM++ 2M",
        "dpmpp_2m_sde": "DPM++ 2M SDE", "dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
        "ddim": "DDIM",
    }
    _SCHEDULER_NAMES = {
        "normal": "Simple", "karras": "Karras",
        "exponential": "Exponential", "sgm_uniform": "SGM Uniform",
        "sgm_quadratic": "SGM Quadratic",
    }

    # file_format option -> (file extension, Pillow format name).
    _FORMATS = {
        "png": (".png", "PNG"),
        "jpeg": (".jpg", "JPEG"),
        "webp": (".webp", "WEBP"),
    }

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"  # reported as "type" in each saved-file result
        self.prefix_append = ""
        self.compress_level = 4  # PNG compress_level passed to Pillow
        self.counter = 0

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "filename_prefix": ("STRING", {
                    "default": "ComfyUI",
                    "tooltip": "Base filename. Supports %seed%, %width%, %height%, %model%, etc.",
                }),
                "file_format": (["png", "jpeg", "webp"], {
                    "tooltip": "Image format to save as.",
                }),
            },
            "optional": {
                "lossless_webp": ("BOOLEAN", {"default": False}),
                "quality": ("INT", {"default": 100, "min": 1, "max": 100}),
                "embed_workflow": ("BOOLEAN", {"default": False}),
                "add_counter_to_filename": ("BOOLEAN", {"default": True}),
            },
            "hidden": {
                "id": "UNIQUE_ID",
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO",
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "process_image"
    OUTPUT_NODE = True

    # ------------------------------------------------------------------
    # Remote hash lookups
    # ------------------------------------------------------------------

    def _run_async(self, coro):
        """Run *coro* to completion from synchronous code and return its result.

        If this thread has no running event loop, ``asyncio.run`` is used
        directly. Otherwise (where ``asyncio.run`` would refuse to start) the
        coroutine runs on a fresh event loop in a short-lived worker thread,
        and this call blocks until it finishes. Exceptions raised by the
        coroutine propagate to the caller unchanged.

        NOTE(review): every call may create a new event loop; if
        RemoteLoraClient caches loop-bound resources (e.g. an HTTP session),
        confirm they tolerate being used from different loops.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread.
            return asyncio.run(coro)

        # A loop is already running here. This path is deliberately outside
        # the try above, so a RuntimeError raised *by the coroutine* is not
        # mistaken for "no running loop" (which would retry an already
        # awaited coroutine).
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def get_lora_hash(self, lora_name):
        """Return the remote hash for *lora_name*, or None if unavailable.

        Best-effort: a failed remote lookup is logged and treated as "no
        hash", so the image is still saved.
        """
        try:
            client = RemoteLoraClient.get_instance()
            return self._run_async(client.get_lora_hash(lora_name))
        except Exception as e:
            logger.warning("[LM-Remote] LoRA hash lookup failed for %s: %s", lora_name, e)
            return None

    def get_checkpoint_hash(self, checkpoint_path):
        """Return the remote hash for the checkpoint at *checkpoint_path*, or None.

        The lookup key is the file's base name without extension. Best-effort
        like :meth:`get_lora_hash`.
        """
        if not checkpoint_path:
            return None
        checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
        try:
            client = RemoteLoraClient.get_instance()
            return self._run_async(client.get_checkpoint_hash(checkpoint_name))
        except Exception as e:
            logger.warning("[LM-Remote] checkpoint hash lookup failed for %s: %s", checkpoint_name, e)
            return None

    # ------------------------------------------------------------------
    # Metadata formatting
    # ------------------------------------------------------------------

    def format_metadata(self, metadata_dict):
        """Render *metadata_dict* as a "parameters" text block.

        Layout:
        - the positive prompt, with the LoRA tag text appended on its own
          line when present;
        - an optional "Negative prompt: ..." line;
        - one comma-separated line of Steps, Sampler, CFG scale, Seed, Size,
          Model hash / Model and Lora hashes, for whichever values are
          present and not None.

        Hashes are fetched remotely and truncated to 10 characters.
        Returns "" for an empty or None *metadata_dict*.
        """
        if not metadata_dict:
            return ""

        # Keys may be present with a None value; treat those as empty text.
        prompt = metadata_dict.get("prompt") or ""
        negative_prompt = metadata_dict.get("negative_prompt") or ""
        loras_text = metadata_dict.get("loras") or ""

        lora_hashes = {}
        if loras_text:
            prompt_with_loras = f"{prompt}\n{loras_text}"
            for lora_name, _strength in self._LORA_TAG.findall(loras_text):
                if lora_name in lora_hashes:
                    continue  # already resolved; avoid a second remote call
                hash_value = self.get_lora_hash(lora_name)
                if hash_value:
                    lora_hashes[lora_name] = hash_value
        else:
            prompt_with_loras = prompt

        metadata_parts = [prompt_with_loras]
        if negative_prompt:
            metadata_parts.append(f"Negative prompt: {negative_prompt}")

        params = []

        def add_param(label, key):
            value = metadata_dict.get(key)
            if value is not None:
                params.append(f"{label}: {value}")

        add_param("Steps", "steps")

        sampler = metadata_dict.get("sampler")
        scheduler = metadata_dict.get("scheduler")
        sampler_name = self._SAMPLER_NAMES.get(sampler, sampler)
        scheduler_name = self._SCHEDULER_NAMES.get(scheduler, scheduler)
        if sampler_name:
            if scheduler_name:
                params.append(f"Sampler: {sampler_name} {scheduler_name}")
            else:
                params.append(f"Sampler: {sampler_name}")

        # Prefer guidance, then cfg_scale, then cfg. Skip keys whose value is
        # None so a None guidance does not hide a real CFG value.
        for cfg_key in ("guidance", "cfg_scale", "cfg"):
            if metadata_dict.get(cfg_key) is not None:
                params.append(f"CFG scale: {metadata_dict[cfg_key]}")
                break

        add_param("Seed", "seed")
        add_param("Size", "size")

        checkpoint = metadata_dict.get("checkpoint")
        if checkpoint is not None:
            checkpoint_name = os.path.splitext(os.path.basename(checkpoint))[0]
            model_hash = self.get_checkpoint_hash(checkpoint)
            if model_hash:
                params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
            else:
                params.append(f"Model: {checkpoint_name}")

        if lora_hashes:
            lora_hash_parts = [f"{n}: {h[:10]}" for n, h in lora_hashes.items()]
            params.append(f'Lora hashes: "{", ".join(lora_hash_parts)}"')

        if params:  # avoid a trailing empty line when nothing is known
            metadata_parts.append(", ".join(params))
        return "\n".join(metadata_parts)

    def format_filename(self, filename, metadata_dict):
        """Expand %placeholder% segments in *filename*.

        Supported placeholders:
        - %seed%, %width%, %height%;
        - %pprompt[:N]% and %nprompt[:N]% (newlines become spaces);
        - %model[:N]%;
        - %date[:FORMAT]%, where FORMAT tokens are yyyy, yy, MM, dd, hh, mm
          and ss (default yyyyMMddhhmmss).

        A numeric ":N" truncates the value to N characters; a non-numeric
        one is ignored. %model% falls back to "model_unavailable".
        Placeholders whose data is missing, and unknown placeholders, are
        left as-is. *metadata_dict* may be empty or None.
        """
        from datetime import datetime

        metadata_dict = metadata_dict or {}
        now = datetime.now()  # one timestamp shared by every %date% segment

        for segment in self.pattern_format.findall(filename):
            key, *args = segment.strip("%").split(":")
            replacement = None

            if key == "seed":
                seed = metadata_dict.get("seed")
                if seed is not None:
                    replacement = str(seed)
            elif key in ("width", "height") and "size" in metadata_dict:
                index = 0 if key == "width" else 1
                replacement = self._size_component(metadata_dict["size"], index)
            elif key == "pprompt" and "prompt" in metadata_dict:
                replacement = self._prompt_fragment(metadata_dict["prompt"], args)
            elif key == "nprompt" and "negative_prompt" in metadata_dict:
                replacement = self._prompt_fragment(metadata_dict["negative_prompt"], args)
            elif key == "model":
                replacement = self._model_fragment(metadata_dict.get("checkpoint"), args)
            elif key == "date":
                replacement = self._date_fragment(now, args)

            if replacement is not None:
                filename = filename.replace(segment, replacement)

        return filename

    @staticmethod
    def _truncate(text, args):
        """Cut *text* to the length given by ``args[0]`` when it is an integer."""
        if args:
            try:
                return text[: int(args[0])]
            except ValueError:
                pass  # e.g. "%pprompt:abc%": keep the full text rather than crash
        return text

    @staticmethod
    def _size_component(size, index):
        """Return width (0) or height (1) from *size* ("WxH" or a sequence) as str, or None."""
        try:
            if isinstance(size, str):
                return size.split("x")[index]
            return str(size[index])
        except (IndexError, KeyError, TypeError):
            return None

    @classmethod
    def _prompt_fragment(cls, text, args):
        """Single-line, optionally truncated, stripped prompt text for a filename."""
        flattened = (text or "").replace("\n", " ")
        return cls._truncate(flattened, args).strip()

    @classmethod
    def _model_fragment(cls, checkpoint, args):
        """Checkpoint base name without extension, or "model_unavailable"."""
        if isinstance(checkpoint, (bytes, os.PathLike)):
            checkpoint = os.fsdecode(checkpoint)
        if isinstance(checkpoint, str) and checkpoint:
            model = os.path.splitext(os.path.basename(checkpoint))[0]
        else:
            model = "model_unavailable"
        return cls._truncate(model, args)

    @staticmethod
    def _date_fragment(now, args):
        """Format *now* using the %date% token syntax (default yyyyMMddhhmmss)."""
        # Longer tokens come first so "yyyy" is consumed before "yy".
        tokens = (
            ("yyyy", f"{now.year:04d}"), ("yy", f"{now.year % 100:02d}"),
            ("MM", f"{now.month:02d}"), ("dd", f"{now.day:02d}"),
            ("hh", f"{now.hour:02d}"), ("mm", f"{now.minute:02d}"),
            ("ss", f"{now.second:02d}"),
        )
        result = args[0] if args else "yyyyMMddhhmmss"
        for token, value in tokens:
            result = result.replace(token, value)
        return result

    # ------------------------------------------------------------------
    # Image saving
    # ------------------------------------------------------------------

    @staticmethod
    def _collect_metadata(node_id):
        """Best-effort fetch of generation metadata from LoRA Manager's collector.

        The upstream package directory name varies across installs, so the
        collector modules are located by name suffix in ``sys.modules``
        rather than imported directly. Returns {} if they are not loaded or
        anything fails.
        """
        import sys

        try:
            get_metadata = None
            processor = None
            # Iterate a snapshot: the live dict can change size if another
            # thread imports a module meanwhile.
            for mod_name, mod in list(sys.modules.items()):
                if mod is None:
                    continue
                if mod_name.endswith(".py.metadata_collector") and hasattr(mod, "get_metadata"):
                    get_metadata = mod.get_metadata
                if mod_name.endswith(".py.metadata_collector.metadata_processor") and hasattr(mod, "MetadataProcessor"):
                    processor = mod.MetadataProcessor
                if get_metadata and processor:
                    break

            if not (get_metadata and processor):
                logger.debug("[LM-Remote] metadata_collector not found in loaded modules")
                return {}
            return processor.to_dict(get_metadata(), node_id) or {}
        except Exception:
            logger.debug(
                "[LM-Remote] metadata_collector not available, saving without generation metadata",
                exc_info=True,
            )
            return {}

    @staticmethod
    def _workflow_json(extra_pnginfo):
        """Serialize ``extra_pnginfo["workflow"]`` to JSON, or None if absent/unserializable."""
        if not extra_pnginfo or "workflow" not in extra_pnginfo:
            return None
        try:
            return json.dumps(extra_pnginfo["workflow"])
        except (TypeError, ValueError) as e:
            logger.error("Error serializing workflow: %s", e)
            return None

    @staticmethod
    def _build_exif(file_format, metadata, workflow_json):
        """Return EXIF bytes carrying *metadata* (and, for WebP, the workflow), or None.

        Returns None when piexif is not installed or there is nothing to embed.
        """
        if piexif is None:
            return None
        try:
            exif_dict = {}
            if metadata:
                # EXIF UserComment: 8-byte charset header "UNICODE\0", then UTF-16BE text.
                exif_dict["Exif"] = {piexif.ExifIFD.UserComment: b"UNICODE\0" + metadata.encode("utf-16be")}
            if file_format == "webp" and workflow_json is not None:
                exif_dict["0th"] = {piexif.ImageIFD.ImageDescription: "Workflow:" + workflow_json}
            return piexif.dump(exif_dict) if exif_dict else None
        except Exception as e:
            logger.error("Error adding EXIF data: %s", e)
            return None

    def _encoder_options(self, file_format, metadata, workflow_json, lossless_webp, quality):
        """Build the Pillow ``Image.save`` keyword arguments for *file_format*.

        The encoder settings (quality, lossless, compression) are always
        applied. Metadata is embedded when possible: PNG text chunks for PNG,
        EXIF (if piexif is available) for JPEG and WebP.
        """
        if file_format == "png":
            pnginfo = PngImagePlugin.PngInfo()
            if metadata:
                pnginfo.add_text("parameters", metadata)
            if workflow_json is not None:
                pnginfo.add_text("workflow", workflow_json)
            return {"compress_level": self.compress_level, "pnginfo": pnginfo}

        if file_format == "jpeg":
            options = {"quality": quality, "optimize": True}
        else:  # webp
            options = {"quality": quality, "lossless": lossless_webp, "method": 0}

        exif = self._build_exif(file_format, metadata, workflow_json)
        if exif is not None:
            options["exif"] = exif
        return options

    @staticmethod
    def _to_pil(image):
        """Convert one image (torch tensor or array-like) to an 8-bit PIL image.

        Values are scaled by 255 and clipped to [0, 255], so inputs are
        presumably floats in [0, 1].
        """
        array = image.cpu().numpy() if hasattr(image, "cpu") else np.asarray(image)
        return Image.fromarray(np.clip(255.0 * array, 0, 255).astype(np.uint8))

    def save_images(self, images, filename_prefix, file_format, id, prompt=None,
                    extra_pnginfo=None, lossless_webp=True, quality=100,
                    embed_workflow=False, add_counter_to_filename=True):
        """Write every image in *images* to ComfyUI's output directory.

        Args:
            images: Sequence of single images. ``shape[1]`` / ``shape[0]`` of
                the first image are passed as width / height, so each is
                presumably (height, width, channels).
            filename_prefix: Base name; %placeholders% are expanded via
                :meth:`format_filename`.
            file_format: "png", "jpeg" or "webp".
            id: ComfyUI node id, passed to the metadata collector.
            prompt: Accepted for interface compatibility; not used.
            extra_pnginfo: Dict that may contain a "workflow" entry to embed.
            lossless_webp: WebP lossless flag.
            quality: JPEG/WebP quality.
            embed_workflow: Embed the workflow JSON (PNG and WebP only).
            add_counter_to_filename: Append ComfyUI's ``_NNNNN_`` counter.
                When False and the batch has several images, ``_<index>``
                is appended instead so they do not overwrite each other.

        Returns:
            List of ``{"filename", "subfolder", "type"}`` dicts, one per image
            written. A per-image save error is logged and that image skipped.

        Raises:
            ValueError: if *file_format* is not a supported option.
        """
        if file_format not in self._FORMATS:
            raise ValueError(f"Unsupported file_format: {file_format!r}")
        if len(images) == 0:  # len(), not truthiness: tensors are ambiguous as bools
            return []

        metadata_dict = self._collect_metadata(id)
        metadata = self.format_metadata(metadata_dict)
        filename_prefix = self.format_filename(filename_prefix, metadata_dict)

        first = images[0]
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            filename_prefix, self.output_dir, first.shape[1], first.shape[0]
        )
        os.makedirs(full_output_folder, exist_ok=True)

        # Everything below is identical for each image in the batch; build it once.
        workflow_json = self._workflow_json(extra_pnginfo) if embed_workflow else None
        extension, pil_format = self._FORMATS[file_format]
        save_kwargs = self._encoder_options(file_format, metadata, workflow_json, lossless_webp, quality)
        is_batch = len(images) > 1

        results = []
        for i, image in enumerate(images):
            img = self._to_pil(image)

            if add_counter_to_filename:
                base_filename = f"{filename}_{counter + i:05}_"
            elif is_batch:
                base_filename = f"{filename}_{i}"
            else:
                base_filename = filename
            file = base_filename + extension

            try:
                img.save(os.path.join(full_output_folder, file), format=pil_format, **save_kwargs)
            except Exception as e:
                logger.error("Error saving image: %s", e)
                continue
            results.append({"filename": file, "subfolder": subfolder, "type": self.type})

        return results

    def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png",
                      prompt=None, extra_pnginfo=None, lossless_webp=True, quality=100,
                      embed_workflow=False, add_counter_to_filename=True):
        """ComfyUI entry point: save *images*, then pass them through.

        A 3-D input is treated as a single image; otherwise the input is
        split along its first axis, or used as-is if it is already a list
        or tuple. The returned value is the input object itself, not that
        split list, so downstream nodes still receive the original IMAGE
        value (normally a batched tensor).
        """
        os.makedirs(self.output_dir, exist_ok=True)

        if isinstance(images, (list, tuple)):
            frames = list(images)
        elif len(images.shape) == 3:
            frames = [images]
        else:
            frames = list(images)

        self.save_images(
            frames, filename_prefix, file_format, id, prompt, extra_pnginfo,
            lossless_webp, quality, embed_workflow, add_counter_to_filename,
        )
        return (images,)
|