Add Latent Split node (2/4 spatial split) for ComfyUI
Splits a [B,C,H,W] latent into left/right or top/bottom halves, or four quadrants. Model-agnostic (SD, SDXL, Flux / FLUX.1 Krea); preserves batch dimension and noise_mask. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,121 @@
|
||||
"""Latent splitter nodes for ComfyUI.
|
||||
|
||||
Splits a latent along its spatial dimensions into halves (left/right or
|
||||
top/bottom) or quarters. The latent tensor has shape ``[B, C, H, W]`` where H/W
|
||||
are 1/8 of the image size, so this is model-agnostic: it only touches the H/W
|
||||
axes and works for SD/SDXL (4 channels) and Flux / FLUX.1 Krea (16 channels)
|
||||
alike.
|
||||
|
||||
Outputs are filled in row-major (reading) order:
|
||||
|
||||
left / right -> latent_1 = left, latent_2 = right
|
||||
top / bottom -> latent_1 = top, latent_2 = bottom
|
||||
quad (4) -> latent_1 = top-left, latent_2 = top-right,
|
||||
latent_3 = bottom-left, latent_4 = bottom-right
|
||||
|
||||
Unused slots (3 and 4 in the two-way modes) return ``None``; leave them
|
||||
unconnected.
|
||||
"""
|
||||
|
||||
SPLIT_MODES = ["left / right", "top / bottom", "quad (4)"]
|
||||
|
||||
|
||||
def _halve(t, axis):
|
||||
"""Split a tensor in two along ``axis`` at its midpoint (floor division).
|
||||
|
||||
Returns views; the caller makes them contiguous when wrapping into a latent.
|
||||
"""
|
||||
n = t.shape[axis]
|
||||
mid = n // 2
|
||||
return t.narrow(axis, 0, mid), t.narrow(axis, mid, n - mid)
|
||||
|
||||
|
||||
def _wrap(latent, samples, mask):
|
||||
"""Copy the latent dict, swapping in a sub-tile of samples (and its mask)."""
|
||||
out = {k: v for k, v in latent.items() if k not in ("samples", "noise_mask")}
|
||||
out["samples"] = samples.contiguous()
|
||||
if mask is not None:
|
||||
out["noise_mask"] = mask.contiguous()
|
||||
return out
|
||||
|
||||
|
||||
class LatentSplit:
|
||||
"""Split one latent into 2 (left/right or top/bottom) or 4 tiles."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"samples": ("LATENT", {"tooltip": "Latent to split."}),
|
||||
"mode": (
|
||||
SPLIT_MODES,
|
||||
{
|
||||
"default": "quad (4)",
|
||||
"tooltip": "left/right splits the width, top/bottom "
|
||||
"splits the height, quad splits both into 4 tiles.",
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT", "LATENT", "LATENT", "LATENT")
|
||||
RETURN_NAMES = ("latent_1", "latent_2", "latent_3", "latent_4")
|
||||
OUTPUT_TOOLTIPS = (
|
||||
"left / top / top-left tile",
|
||||
"right / bottom / top-right tile",
|
||||
"bottom-left tile (quad only, else None)",
|
||||
"bottom-right tile (quad only, else None)",
|
||||
)
|
||||
FUNCTION = "split"
|
||||
CATEGORY = "latent"
|
||||
DESCRIPTION = (
|
||||
"Split a latent into halves or quarters along its spatial axes.\n"
|
||||
"Model-agnostic (works with SD, SDXL, Flux / FLUX.1 Krea).\n\n"
|
||||
"Outputs fill in reading order:\n"
|
||||
" left/right -> 1=left, 2=right\n"
|
||||
" top/bottom -> 1=top, 2=bottom\n"
|
||||
" quad -> 1=TL, 2=TR, 3=BL, 4=BR\n"
|
||||
"Odd latent dimensions split with the extra row/column going to the "
|
||||
"second tile."
|
||||
)
|
||||
|
||||
def split(self, samples, mode):
|
||||
latent = samples
|
||||
s = latent["samples"] # [B, C, H, W]
|
||||
mask = latent.get("noise_mask") # optional [B, 1, H', W']
|
||||
|
||||
if mode == "left / right":
|
||||
s_parts = _halve(s, -1)
|
||||
m_parts = _halve(mask, -1) if mask is not None else (None, None)
|
||||
tiles = list(zip(s_parts, m_parts))
|
||||
elif mode == "top / bottom":
|
||||
s_parts = _halve(s, -2)
|
||||
m_parts = _halve(mask, -2) if mask is not None else (None, None)
|
||||
tiles = list(zip(s_parts, m_parts))
|
||||
else: # quad (4)
|
||||
top_s, bot_s = _halve(s, -2)
|
||||
tl, tr = _halve(top_s, -1)
|
||||
bl, br = _halve(bot_s, -1)
|
||||
s_parts = [tl, tr, bl, br]
|
||||
if mask is not None:
|
||||
top_m, bot_m = _halve(mask, -2)
|
||||
ml, mr = _halve(top_m, -1)
|
||||
mbl, mbr = _halve(bot_m, -1)
|
||||
m_parts = [ml, mr, mbl, mbr]
|
||||
else:
|
||||
m_parts = [None, None, None, None]
|
||||
tiles = list(zip(s_parts, m_parts))
|
||||
|
||||
outputs = [_wrap(latent, sub_s, sub_m) for sub_s, sub_m in tiles]
|
||||
while len(outputs) < 4:
|
||||
outputs.append(None)
|
||||
return tuple(outputs)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentSplit": LatentSplit,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LatentSplit": "Latent Split (2 / 4)",
|
||||
}
|
||||
Reference in New Issue
Block a user