Replace per-pixel Python loop with vectorized torch.arange + slice operations. Fix DifferentialDiffusion node position to avoid visual overlap with SplitImageToTileList node 14 on the canvas. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
96 lines
4.6 KiB
Python
96 lines
4.6 KiB
Python
import torch
|
|
|
|
|
|
class GenerateSeamMask:
    """Build a mask highlighting the seams of a tiled redraw pass.

    The tile grid is recomputed from the image size, tile size and overlap
    (intended to match SplitImageToTileList's grid logic). For every pair of
    adjacent tiles, a band ``seam_width`` pixels wide is centred on the middle
    of their overlap region, along both the x axis (vertical bands) and the
    y axis (horizontal bands).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image_width": ("INT", {"default": 2048, "min": 64, "max": 16384, "step": 1,
                                        "tooltip": "Width of the image (from GetImageSize)."}),
                "image_height": ("INT", {"default": 2048, "min": 64, "max": 16384, "step": 1,
                                         "tooltip": "Height of the image (from GetImageSize)."}),
                "tile_width": ("INT", {"default": 1024, "min": 64, "max": 8192, "step": 8,
                                       "tooltip": "Tile width used in the main tiled redraw pass."}),
                "tile_height": ("INT", {"default": 1024, "min": 64, "max": 8192, "step": 8,
                                        "tooltip": "Tile height used in the main tiled redraw pass."}),
                "overlap": ("INT", {"default": 128, "min": 0, "max": 4096, "step": 1,
                                    "tooltip": "Overlap used in the main tiled redraw pass."}),
                "seam_width": ("INT", {"default": 64, "min": 8, "max": 512, "step": 8,
                                       "tooltip": "Width of the seam bands to fix (in pixels)."}),
                "mode": (["binary", "gradient"], {"default": "binary",
                                                  "tooltip": "binary: hard 0/1 mask. gradient: linear falloff for use with Differential Diffusion."}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate"
    CATEGORY = "image/upscaling"
    DESCRIPTION = "Generates a mask image with white bands at tile seam positions. Used for targeted seam fix denoising."

    @staticmethod
    def _get_tile_positions(length, tile_size, overlap):
        """Compute 1D tile start/end positions, matching SplitImageToTileList's get_grid_coords.

        Tiles advance by ``tile_size - overlap``. The final tile is shifted back
        so that it ends exactly at ``length`` (clamped at 0 when the tile is
        larger than the image). When ``overlap >= tile_size`` the stride is
        clamped to 1, which yields one tile per pixel offset.

        Returns:
            List of ``(start, end)`` tuples, half-open, in increasing order.
        """
        stride = max(1, tile_size - overlap)
        positions = []
        p = 0
        while p < length:
            p_end = min(p + tile_size, length)
            p_start = max(0, p_end - tile_size)
            positions.append((p_start, p_end))
            if p_end >= length:
                break
            p += stride
        return positions

    @staticmethod
    def _seam_spans(tiles, length, half_w):
        """Yield ``(start, end, center)`` of the seam band between each adjacent tile pair.

        ``center`` is the midpoint of the two tiles' overlap. The band is
        ``[center - half_w, center + half_w)``, clamped to ``[0, length)``.
        """
        for (a_start, a_end), (b_start, b_end) in zip(tiles, tiles[1:]):
            ovl_start = max(a_start, b_start)
            ovl_end = min(a_end, b_end)
            center = (ovl_start + ovl_end) // 2
            yield max(0, center - half_w), min(length, center + half_w), center

    @classmethod
    def _seam_profile(cls, length, tile_size, overlap, half_w, gradient):
        """Return a 1D tensor of ``length`` seam weights along one axis.

        Binary: 1.0 inside any seam band, 0.0 elsewhere.
        Gradient: ``1 - |i - center| / half_w`` inside each band. Where bands
        overlap, the larger value wins.
        """
        tiles = cls._get_tile_positions(length, tile_size, overlap)
        profile = torch.zeros(length)
        for start, end, center in cls._seam_spans(tiles, length, half_w):
            if gradient:
                coords = torch.arange(start, end, dtype=torch.float32)
                ramp = 1.0 - (coords - center).abs() / half_w
                profile[start:end] = torch.maximum(profile[start:end], ramp)
            else:
                profile[start:end] = 1.0
        return profile

    def generate(self, image_width, image_height, tile_width, tile_height, overlap, seam_width, mode="binary"):
        """Generate the seam mask image.

        Args:
            image_width, image_height: Size of the image being fixed.
            tile_width, tile_height, overlap: Tiling parameters of the main pass.
            seam_width: Total width of each seam band; the band spans
                ``seam_width // 2`` pixels on either side of the seam centre.
            mode: ``"gradient"`` for a linear falloff peaking at 1.0 on the seam
                centre. Any other value produces a hard 0/1 (binary) mask.

        Returns:
            1-tuple holding a tensor of shape ``(1, image_height, image_width, 3)``
            with values in ``[0, 1]``. Vertical and horizontal bands are combined
            with an elementwise max.
        """
        half_w = seam_width // 2
        gradient = mode == "gradient"

        # Every vertical band depends only on x and every horizontal band only
        # on y, so build one 1D profile per axis and combine them once. This
        # avoids touching a full-height/full-width slice per seam.
        col_weights = self._seam_profile(image_width, tile_width, overlap, half_w, gradient)
        row_weights = self._seam_profile(image_height, tile_height, overlap, half_w, gradient)

        grid = torch.maximum(row_weights.unsqueeze(1), col_weights.unsqueeze(0))  # (H, W)
        # repeat() (not expand()) so the IMAGE tensor owns contiguous memory.
        mask = grid.unsqueeze(0).unsqueeze(-1).repeat(1, 1, 1, 3)
        return (mask,)
|