diff --git a/image_preview.py b/image_preview.py index e081452..67f190e 100644 --- a/image_preview.py +++ b/image_preview.py @@ -13,6 +13,55 @@ import node_helpers from comfy.cli_args import args from comfy_execution.graph_utils import ExecutionBlocker +try: + from server import PromptServer + from aiohttp import web +except ImportError: + PromptServer = None + +CHANNELS_FILE = os.path.join(os.path.dirname(__file__), "channels.json") + + +def _read_channels(): + if os.path.exists(CHANNELS_FILE): + try: + with open(CHANNELS_FILE, "r") as f: + return json.load(f) + except (json.JSONDecodeError, ValueError): + return {} + return {} + + +def _write_channels(data): + with open(CHANNELS_FILE, "w") as f: + json.dump(data, f, indent=2) + + +if PromptServer is not None: + @PromptServer.instance.routes.post("/jdl/channel/send") + async def channel_send(request): + body = await request.json() + channel = body.get("channel", "default") + filename = body.get("filename", "") + if not filename: + return web.json_response({"error": "filename required"}, status=400) + channels = _read_channels() + channels[channel] = filename + _write_channels(channels) + return web.json_response({"ok": True, "channel": channel, "filename": filename}) + + @PromptServer.instance.routes.get("/jdl/channel/receive") + async def channel_receive(request): + channel = request.query.get("channel", "default") + channels = _read_channels() + filename = channels.get(channel, "") + return web.json_response({"channel": channel, "filename": filename}) + + @PromptServer.instance.routes.get("/jdl/channel/list") + async def channel_list(request): + channels = _read_channels() + return web.json_response({"channels": list(channels.keys())}) + class JDL_PreviewToLoad: """Previews an image and saves a copy to input/ for use by LoadImage nodes.""" @@ -120,6 +169,54 @@ class JDL_PreviewToLoad: return {"ui": {"images": results, "input_filename": [input_filename]}} +def _load_image_from_path(image_path): + """Shared helper: load an image file and return (IMAGE tensor, MASK tensor).""" + img = node_helpers.pillow(Image.open, image_path) + + output_images = [] + output_masks = [] + w, h = None, None + + for i in ImageSequence.Iterator(img): + i = node_helpers.pillow(ImageOps.exif_transpose, i) + + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + frame = i.convert("RGB") + + if len(output_images) == 0: + w = frame.size[0] + h = frame.size[1] + + if frame.size[0] != w or frame.size[1] != h: + continue + + frame_np = np.array(frame).astype(np.float32) / 255.0 + frame_tensor = torch.from_numpy(frame_np)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + elif i.mode == 'P' and 'transparency' in i.info: + mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") + output_images.append(frame_tensor) + output_masks.append(mask.unsqueeze(0)) + + if img.format == "MPO": + break + + if len(output_images) > 1: + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) + else: + output_image = output_images[0] + output_mask = output_masks[0] + + return (output_image, output_mask) + + class JDL_LoadImage: """Load an image from the input directory with an active switch to skip downstream execution.""" @@ -144,50 +241,7 @@ class JDL_LoadImage: return (ExecutionBlocker(None), ExecutionBlocker(None)) image_path = folder_paths.get_annotated_filepath(image) - img = node_helpers.pillow(Image.open, image_path) - - output_images = [] - output_masks = [] - w, h = None, None - - for i in ImageSequence.Iterator(img): - i = node_helpers.pillow(ImageOps.exif_transpose, i) - - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - frame = i.convert("RGB") - - if len(output_images) == 0: - w = frame.size[0] - h = frame.size[1] - - if frame.size[0] != w or frame.size[1] != h: - continue - - frame_np = np.array(frame).astype(np.float32) / 255.0 - frame_tensor = torch.from_numpy(frame_np)[None,] - if 'A' in i.getbands(): - mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) - elif i.mode == 'P' and 'transparency' in i.info: - mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) - else: - mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") - output_images.append(frame_tensor) - output_masks.append(mask.unsqueeze(0)) - - if img.format == "MPO": - break - - if len(output_images) > 1: - output_image = torch.cat(output_images, dim=0) - output_mask = torch.cat(output_masks, dim=0) - else: - output_image = output_images[0] - output_mask = output_masks[0] - - return (output_image, output_mask) + return _load_image_from_path(image_path) @classmethod def IS_CHANGED(s, image, active): @@ -208,12 +262,61 @@ class JDL_LoadImage: return True +class JDL_ImageReceiver: + """Load an image from a named channel (cross-workflow image passing).""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "channel": ("STRING", {"default": "default"}), + "active": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("IMAGE", "MASK", "STRING") + RETURN_NAMES = ("image", "mask", "filename") + FUNCTION = "receive" + CATEGORY = "utils/image" + + def receive(self, channel, active): + if not active: + return (ExecutionBlocker(None), ExecutionBlocker(None), ExecutionBlocker(None)) + + channels = _read_channels() + filename = channels.get(channel, "") + if not filename: + return (ExecutionBlocker(None), ExecutionBlocker(None), ExecutionBlocker(None)) + + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, filename) + if not os.path.isfile(image_path): + return (ExecutionBlocker(None), ExecutionBlocker(None), ExecutionBlocker(None)) + + output_image, output_mask = _load_image_from_path(image_path) + return (output_image, output_mask, filename) + + @classmethod + def IS_CHANGED(s, channel, active): + if not active: + return "inactive" + channels = _read_channels() + filename = channels.get(channel, "") + return hashlib.sha256(filename.encode()).hexdigest() + + @classmethod + def VALIDATE_INPUTS(s, channel, active): + return True + + NODE_CLASS_MAPPINGS = { "JDL_PreviewToLoad": JDL_PreviewToLoad, "JDL_LoadImage": JDL_LoadImage, + "JDL_ImageReceiver": JDL_ImageReceiver, } NODE_DISPLAY_NAME_MAPPINGS = { "JDL_PreviewToLoad": "Preview to Load Image", "JDL_LoadImage": "Load Image (Active Switch)", + "JDL_ImageReceiver": "Image Receiver (Channel)", } diff --git a/web/image_preview.js b/web/image_preview.js index 7443d6c..9a7d742 100644 --- a/web/image_preview.js +++ b/web/image_preview.js @@ -53,6 +53,34 @@ app.registerExtension({ app.graph.setDirtyCanvas(true, true); }); + this.addWidget("text", "channel", "default", () => {}); + + this.addWidget("button", "Send to Channel", null, () => { + const channelWidget = this.widgets?.find(w => w.name === "channel"); + const channel = channelWidget?.value || "default"; + + const filename = this.last_input_filename; + if (!filename) { + console.warn("[PreviewToLoad] No filename available. Run the workflow first."); + return; + } + + fetch("/jdl/channel/send", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ channel, filename }), + }) + .then(r => r.json()) + .then(data => { + if (data.ok) { + console.log(`[PreviewToLoad] Sent "${filename}" to channel "${channel}"`); + } else { + console.warn("[PreviewToLoad] Channel send failed:", data.error); + } + }) + .catch(err => console.error("[PreviewToLoad] Channel send error:", err)); + }); + this.setSize(this.computeSize()); };