"""Latent splitter nodes for ComfyUI.

Splits a latent along its spatial dimensions into halves (left/right or
top/bottom) or quarters.

The latent tensor has shape ``[B, C, H, W]`` where H/W are 1/8 of the image
size, so this is model-agnostic: it only touches the H/W axes and works for
SD/SDXL (4 channels) and Flux / FLUX.1 Krea (16 channels) alike.

Outputs are filled in row-major (reading) order:

    left / right -> latent_1 = left, latent_2 = right
    top / bottom -> latent_1 = top, latent_2 = bottom
    quad (4)     -> latent_1 = top-left, latent_2 = top-right,
                    latent_3 = bottom-left, latent_4 = bottom-right

Unused slots (3 and 4 in the two-way modes) return ``None``; leave them
unconnected.

An attached ``noise_mask`` is split together with the samples. Its cut is
placed at the position proportional to the latent cut, so a mask whose
spatial size differs from the latent's stays aligned with each tile.
"""

SPLIT_MODES = ["left / right", "top / bottom", "quad (4)"]


def _halve(t, axis):
    """Split a tensor in two along ``axis`` at its midpoint (floor division).

    With an odd length the extra row/column goes to the second part; a length
    of 1 therefore yields an empty first part.

    Returns views; the caller makes them contiguous when wrapping into a
    latent.
    """
    n = t.shape[axis]
    mid = n // 2
    return t.narrow(axis, 0, mid), t.narrow(axis, mid, n - mid)


def _halve_mask(mask, axis, ref_len):
    """Split ``mask`` along ``axis`` where a latent of length ``ref_len`` is cut.

    The latent is cut at ``ref_len // 2``. The mask is cut at the same
    *relative* position rather than at its own midpoint, because the two
    differ when the latent length is odd and the mask has another resolution
    (e.g. latent 5 -> cut at 2; mask 40 -> cut at 16, not 20).

    Args:
        mask: Mask tensor, or ``None`` when the latent carries no mask.
        axis: Axis to split (negative indices are fine).
        ref_len: Length of the latent along the same axis.

    Returns:
        A ``(first, second)`` pair of views, or ``(None, None)`` if ``mask``
        is ``None``.
    """
    if mask is None:
        return None, None

    n = mask.shape[axis]
    if ref_len > 0:
        # Round-half-down in pure integer arithmetic: this reproduces the
        # plain ``n // 2`` cut whenever ``ref_len`` is even or ``n == ref_len``.
        cut = ((ref_len // 2) * n + (ref_len - 1) // 2) // ref_len
    else:
        # Degenerate (empty) latent axis: nothing to align with.
        cut = n // 2
    cut = max(0, min(n, cut))
    return mask.narrow(axis, 0, cut), mask.narrow(axis, cut, n - cut)


def _halve_tile(samples, mask, axis):
    """Halve a ``(samples, mask)`` tile along ``axis``; return two such tiles."""
    s_parts = _halve(samples, axis)
    m_parts = _halve_mask(mask, axis, samples.shape[axis])
    return list(zip(s_parts, m_parts))


def _wrap(latent, samples, mask):
    """Copy the latent dict, swapping in a sub-tile of samples (and its mask).

    The copy is shallow: every key other than ``samples`` / ``noise_mask`` is
    carried over by reference. ``noise_mask`` is only set when ``mask`` is not
    ``None``.
    """
    out = {k: v for k, v in latent.items() if k not in ("samples", "noise_mask")}
    out["samples"] = samples.contiguous()
    if mask is not None:
        out["noise_mask"] = mask.contiguous()
    return out


class LatentSplit:
    """Split one latent into 2 (left/right or top/bottom) or 4 tiles."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT", {"tooltip": "Latent to split."}),
                "mode": (
                    SPLIT_MODES,
                    {
                        "default": "quad (4)",
                        "tooltip": "left/right splits the width, top/bottom "
                        "splits the height, quad splits both into 4 tiles.",
                    },
                ),
            }
        }

    RETURN_TYPES = ("LATENT", "LATENT", "LATENT", "LATENT")
    RETURN_NAMES = ("latent_1", "latent_2", "latent_3", "latent_4")
    OUTPUT_TOOLTIPS = (
        "left / top / top-left tile",
        "right / bottom / top-right tile",
        "bottom-left tile (quad only, else None)",
        "bottom-right tile (quad only, else None)",
    )
    FUNCTION = "split"
    CATEGORY = "latent"
    DESCRIPTION = (
        "Split a latent into halves or quarters along its spatial axes.\n"
        "Model-agnostic (works with SD, SDXL, Flux / FLUX.1 Krea).\n\n"
        "Outputs fill in reading order:\n"
        " left/right -> 1=left, 2=right\n"
        " top/bottom -> 1=top, 2=bottom\n"
        " quad -> 1=TL, 2=TR, 3=BL, 4=BR\n"
        "Odd latent dimensions split with the extra row/column going to the "
        "second tile."
    )

    def split(self, samples, mode):
        """Split ``samples`` into tiles according to ``mode``.

        Args:
            samples: Latent dict; must contain ``"samples"`` and may contain
                ``"noise_mask"``. It is not modified.
            mode: One of ``SPLIT_MODES``. Any value other than
                ``"left / right"`` or ``"top / bottom"`` is treated as the
                quad split.

        Returns:
            A 4-tuple of latent dicts in reading order. In the two-way modes
            the last two entries are ``None``.
        """
        latent = samples
        s = latent["samples"]  # [B, C, H, W]
        mask = latent.get("noise_mask")  # optional [B, 1, H', W']

        if mode == "left / right":
            tiles = _halve_tile(s, mask, -1)
        elif mode == "top / bottom":
            tiles = _halve_tile(s, mask, -2)
        else:  # quad (4)
            # Rows first, then columns within each row: TL, TR, BL, BR.
            tiles = [
                tile
                for row_s, row_m in _halve_tile(s, mask, -2)
                for tile in _halve_tile(row_s, row_m, -1)
            ]

        outputs = [_wrap(latent, sub_s, sub_m) for sub_s, sub_m in tiles]
        # Pad the unused output slots so the tuple always has 4 entries.
        outputs.extend([None] * (4 - len(outputs)))
        return tuple(outputs)


NODE_CLASS_MAPPINGS = {
    "LatentSplit": LatentSplit,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LatentSplit": "Latent Split (2 / 4)",
}