Files
Ethanfel 291a0a1f1c Add Latent Split node (2/4 spatial split) for ComfyUI
Splits a [B,C,H,W] latent into left/right or top/bottom halves, or four
quadrants. Model-agnostic (SD, SDXL, Flux / FLUX.1 Krea); preserves batch
dimension and noise_mask.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 13:25:47 +02:00

122 lines
4.2 KiB
Python

"""Latent splitter nodes for ComfyUI.

Splits a latent along its spatial dimensions into halves (left/right or
top/bottom) or quarters. The latent tensor has shape ``[B, C, H, W]`` where H/W
are 1/8 of the image size, so this is model-agnostic: it only touches the H/W
axes and works for SD/SDXL (4 channels) and Flux / FLUX.1 Krea (16 channels)
alike. An optional ``noise_mask`` is split into matching tiles alongside the
samples, and every other key of the input latent dict is carried into each
output.

Outputs are filled in row-major (reading) order:

    left / right  -> latent_1 = left, latent_2 = right
    top / bottom  -> latent_1 = top,  latent_2 = bottom
    quad (4)      -> latent_1 = top-left,    latent_2 = top-right,
                     latent_3 = bottom-left, latent_4 = bottom-right

When a latent dimension is odd, the extra row/column goes to the second tile of
that split. Unused slots (3 and 4 in the two-way modes) return ``None``; leave
them unconnected.
"""
# Choices offered by the node's ``mode`` combo input. ``LatentSplit.split``
# branches on these exact strings, so they must not be reworded casually.
SPLIT_MODES = ["left / right", "top / bottom", "quad (4)"]
def _halve(t, axis):
    """Cut ``t`` in two along ``axis``.

    The first piece gets ``t.shape[axis] // 2`` entries and the second gets
    the remainder, so an odd leftover lands in the second piece. Both pieces
    are views made with ``Tensor.narrow`` (no data is copied here); ``_wrap``
    makes them contiguous when it packs them into a latent.
    """
    length = t.shape[axis]
    split_at = length // 2
    front = t.narrow(axis, 0, split_at)
    back = t.narrow(axis, split_at, length - split_at)
    return front, back
def _wrap(latent, samples, mask):
    """Return a new latent dict holding one tile.

    The input dict is left untouched. All of its keys except ``samples`` and
    ``noise_mask`` are carried over (values are shared, not copied); then
    ``samples`` is set to the tile made contiguous, followed by
    ``noise_mask`` (also made contiguous) only when ``mask`` is given.
    """
    tile = dict(latent)
    # Drop both keys so they are re-inserted at the end, after the carried keys.
    for replaced in ("samples", "noise_mask"):
        tile.pop(replaced, None)
    tile["samples"] = samples.contiguous()
    if mask is None:
        return tile
    tile["noise_mask"] = mask.contiguous()
    return tile
def _split_pair(tile, axis):
    """Cut a ``(samples, mask)`` tile in two along a spatial ``axis``.

    The samples are halved with ``_halve`` (first piece gets ``n // 2``
    entries, the second gets the rest). When a mask is present it is cut at
    the same *relative* position, so each mask piece covers the same region
    as its sample piece even when the mask resolution differs from the
    latent's (e.g. a pixel-space mask) or the latent size is odd. A mask
    whose size along ``axis`` is 1 is handed unchanged to both halves instead
    of being cut into an empty piece.

    Args:
        tile: ``(samples, mask)`` pair; ``mask`` may be None.
        axis: negative axis index to split (``-1`` = width, ``-2`` = height).

    Returns:
        Two ``(samples, mask)`` pairs: the first and second piece along ``axis``.

    Raises:
        ValueError: if ``samples`` has fewer than 2 entries along ``axis``,
            which would leave one tile empty.
    """
    samples, mask = tile
    n = samples.shape[axis]
    if n < 2:
        raise ValueError(
            f"Cannot split a latent of size {n} along axis {axis}: "
            "each tile needs at least one row/column."
        )
    s_first, s_second = _halve(samples, axis)
    if mask is None:
        return (s_first, None), (s_second, None)

    m = mask.shape[axis]
    if m == 1:
        # A single row/column spans the whole extent; both halves keep it.
        return (s_first, mask), (s_second, mask)

    # Scale the latent cut (n // 2 out of n) to the mask resolution, rounded
    # to the nearest index, and keep both mask pieces non-empty.
    mid = n // 2
    cut = (mid * m + n // 2) // n
    cut = max(1, min(m - 1, cut))
    return (
        (s_first, mask.narrow(axis, 0, cut)),
        (s_second, mask.narrow(axis, cut, m - cut)),
    )


class LatentSplit:
    """Split one latent into 2 (left/right or top/bottom) or 4 tiles.

    ``split`` always returns four LATENT outputs in reading order; the last
    two are None in the two-way modes.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT", {"tooltip": "Latent to split."}),
                "mode": (
                    SPLIT_MODES,
                    {
                        "default": "quad (4)",
                        "tooltip": "left/right splits the width, top/bottom "
                        "splits the height, quad splits both into 4 tiles.",
                    },
                ),
            }
        }

    RETURN_TYPES = ("LATENT", "LATENT", "LATENT", "LATENT")
    RETURN_NAMES = ("latent_1", "latent_2", "latent_3", "latent_4")
    OUTPUT_TOOLTIPS = (
        "left / top / top-left tile",
        "right / bottom / top-right tile",
        "bottom-left tile (quad only, else None)",
        "bottom-right tile (quad only, else None)",
    )
    FUNCTION = "split"
    CATEGORY = "latent"
    DESCRIPTION = (
        "Split a latent into halves or quarters along its spatial axes.\n"
        "Model-agnostic (works with SD, SDXL, Flux / FLUX.1 Krea).\n\n"
        "Outputs fill in reading order:\n"
        " left/right -> 1=left, 2=right\n"
        " top/bottom -> 1=top, 2=bottom\n"
        " quad -> 1=TL, 2=TR, 3=BL, 4=BR\n"
        "Odd latent dimensions split with the extra row/column going to the "
        "second tile."
    )

    def split(self, samples, mode):
        """Split a latent dict according to ``mode``.

        Args:
            samples: LATENT dict; ``samples["samples"]`` is the latent tensor
                and ``samples.get("noise_mask")`` an optional mask that is
                split alongside it.
            mode: one of ``SPLIT_MODES``.

        Returns:
            A 4-tuple of LATENT dicts in reading order; unused slots are None.

        Raises:
            ValueError: for an unknown ``mode`` or a latent too small to split.
        """
        latent = samples  # the input is named "samples" by the node schema
        whole = (latent["samples"], latent.get("noise_mask"))

        if mode == "left / right":
            tiles = list(_split_pair(whole, -1))
        elif mode == "top / bottom":
            tiles = list(_split_pair(whole, -2))
        elif mode == "quad (4)":
            # Rows first, then each half by columns -> TL, TR, BL, BR.
            top, bottom = _split_pair(whole, -2)
            tiles = [*_split_pair(top, -1), *_split_pair(bottom, -1)]
        else:
            raise ValueError(
                f"Unknown split mode {mode!r}; expected one of {SPLIT_MODES}."
            )

        outputs = [_wrap(latent, sub_s, sub_m) for sub_s, sub_m in tiles]
        outputs += [None] * (4 - len(outputs))
        return tuple(outputs)
# ComfyUI custom-node registration tables: node type id -> implementing class,
# and node type id -> name shown in the UI.
NODE_CLASS_MAPPINGS = dict(LatentSplit=LatentSplit)
NODE_DISPLAY_NAME_MAPPINGS = dict(LatentSplit="Latent Split (2 / 4)")