291a0a1f1c
Splits a [B,C,H,W] latent into left/right or top/bottom halves, or four quadrants. Model-agnostic (SD, SDXL, Flux / FLUX.1 Krea); preserves batch dimension and noise_mask. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
122 lines
4.2 KiB
Python
122 lines
4.2 KiB
Python
"""ComfyUI node that cuts a latent into halves or quarters.

A latent's ``samples`` tensor is laid out as ``[B, C, H, W]``, with H and W
being 1/8 of the image size. Only the last two (spatial) axes are cut, so the
node is model-agnostic: SD/SDXL latents (4 channels) and Flux / FLUX.1 Krea
latents (16 channels) are handled identically. An optional ``noise_mask`` is
cut together with the samples.

Tiles come out in reading order (row by row, left to right):

    left / right  -> latent_1 = left,     latent_2 = right
    top / bottom  -> latent_1 = top,      latent_2 = bottom
    quad (4)      -> latent_1 = top-left, latent_2 = top-right,
                     latent_3 = bottom-left, latent_4 = bottom-right

In the two-way modes slots 3 and 4 are ``None``; leave them unconnected.
"""

# Accepted values of the node's ``mode`` input (used as the combo choices).
SPLIT_MODES = [
    "left / right",
    "top / bottom",
    "quad (4)",
]


def _halve(t, axis):
|
|
"""Split a tensor in two along ``axis`` at its midpoint (floor division).
|
|
|
|
Returns views; the caller makes them contiguous when wrapping into a latent.
|
|
"""
|
|
n = t.shape[axis]
|
|
mid = n // 2
|
|
return t.narrow(axis, 0, mid), t.narrow(axis, mid, n - mid)
|
|
|
|
|
|
def _wrap(latent, samples, mask):
|
|
"""Copy the latent dict, swapping in a sub-tile of samples (and its mask)."""
|
|
out = {k: v for k, v in latent.items() if k not in ("samples", "noise_mask")}
|
|
out["samples"] = samples.contiguous()
|
|
if mask is not None:
|
|
out["noise_mask"] = mask.contiguous()
|
|
return out
|
|
|
|
|
|
class LatentSplit:
    """Split one latent into 2 (left/right or top/bottom) or 4 tiles.

    Only the last two axes of ``samples["samples"]`` (height, width) are cut,
    so the node does not depend on the channel count. The optional
    ``noise_mask`` is cut on the same spatial lines as the samples, even when
    it is stored at a different resolution than the latent.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT", {"tooltip": "Latent to split."}),
                "mode": (
                    SPLIT_MODES,
                    {
                        "default": "quad (4)",
                        "tooltip": "left/right splits the width, top/bottom "
                        "splits the height, quad splits both into 4 tiles.",
                    },
                ),
            }
        }

    RETURN_TYPES = ("LATENT", "LATENT", "LATENT", "LATENT")
    RETURN_NAMES = ("latent_1", "latent_2", "latent_3", "latent_4")
    OUTPUT_TOOLTIPS = (
        "left / top / top-left tile",
        "right / bottom / top-right tile",
        "bottom-left tile (quad only, else None)",
        "bottom-right tile (quad only, else None)",
    )
    FUNCTION = "split"
    CATEGORY = "latent"
    DESCRIPTION = (
        "Split a latent into halves or quarters along its spatial axes.\n"
        "Model-agnostic (works with SD, SDXL, Flux / FLUX.1 Krea).\n\n"
        "Outputs fill in reading order:\n"
        " left/right -> 1=left, 2=right\n"
        " top/bottom -> 1=top, 2=bottom\n"
        " quad -> 1=TL, 2=TR, 3=BL, 4=BR\n"
        "Odd latent dimensions split with the extra row/column going to the "
        "second tile."
    )

    @staticmethod
    def _halve_mask(mask, axis, latent_len):
        """Split ``mask`` along ``axis`` on the same line the latent is split.

        The latent axis of length ``latent_len`` is cut at ``latent_len // 2``.
        ``noise_mask`` may have a different resolution than the latent along
        ``axis`` (e.g. pixel rather than latent resolution). In that case,
        cutting it at its *own* midpoint misaligns it with the samples when
        ``latent_len`` is odd. For example, latent width 13 is cut after
        column 6, but a 104-wide mask would be cut at 52, which is latent
        column 6.5. The latent cut index is therefore scaled by
        ``mask_len / latent_len`` and rounded to the nearest integer. This
        reduces to the plain midpoint whenever mask and latent share a
        resolution.

        Args:
            mask: Mask tensor, or ``None`` when the latent has no noise_mask.
            axis: Axis to split (``-1`` for width, ``-2`` for height).
            latent_len: Length of the latent samples along the same axis.

        Returns:
            ``(first, second)`` views of ``mask``, or ``(None, None)`` if
            ``mask`` is ``None``. A size-1 mask axis cannot be divided, so the
            unchanged mask is returned for both halves instead of leaving one
            of them empty.
        """
        if mask is None:
            return None, None
        mask_len = mask.shape[axis]
        if mask_len == 1:
            # Presumably a broadcast mask along this axis; both tiles keep it.
            return mask, mask
        if latent_len > 0:
            latent_mid = latent_len // 2
            # Integer round-to-nearest of latent_mid * mask_len / latent_len.
            cut = (latent_mid * mask_len + latent_len // 2) // latent_len
        else:
            cut = mask_len // 2
        return mask.narrow(axis, 0, cut), mask.narrow(axis, cut, mask_len - cut)

    def split(self, samples, mode):
        """Split the latent ``samples`` into tiles according to ``mode``.

        Args:
            samples: LATENT dict. ``samples["samples"]`` is the tensor to cut;
                only its last two axes are used. An optional ``"noise_mask"``
                is cut alongside it, and all other keys are copied into every
                tile.
            mode: One of ``SPLIT_MODES``. Any value other than
                ``"left / right"`` or ``"top / bottom"`` is treated as quad.

        Returns:
            Tuple of four latent dicts in reading order. Slots 3 and 4 are
            ``None`` in the two-way modes. A latent axis of size 1 yields an
            empty first tile along that axis.
        """
        latent = samples
        s = latent["samples"]  # assumed [B, C, H, W]; only H (-2) and W (-1) are cut
        mask = latent.get("noise_mask")  # optional; its resolution may differ from s
        height, width = s.shape[-2], s.shape[-1]

        if mode == "left / right":
            s_parts = _halve(s, -1)
            m_parts = self._halve_mask(mask, -1, width)
        elif mode == "top / bottom":
            s_parts = _halve(s, -2)
            m_parts = self._halve_mask(mask, -2, height)
        else:  # quad (4): cut the rows first, then the columns of each half
            top_s, bot_s = _halve(s, -2)
            s_parts = (*_halve(top_s, -1), *_halve(bot_s, -1))
            top_m, bot_m = self._halve_mask(mask, -2, height)
            m_parts = (
                *self._halve_mask(top_m, -1, width),
                *self._halve_mask(bot_m, -1, width),
            )

        outputs = [_wrap(latent, sub_s, sub_m) for sub_s, sub_m in zip(s_parts, m_parts)]
        # Pad to the four declared RETURN_TYPES slots.
        outputs += [None] * (len(self.RETURN_TYPES) - len(outputs))
        return tuple(outputs)


# Node registration: node id -> implementing class, and node id -> UI label.
NODE_CLASS_MAPPINGS = {"LatentSplit": LatentSplit}
NODE_DISPLAY_NAME_MAPPINGS = {"LatentSplit": "Latent Split (2 / 4)"}