feat: add gradient mode to GenerateSeamMask for differential diffusion
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -18,6 +18,8 @@ class GenerateSeamMask:
|
|||||||
"tooltip": "Overlap used in the main tiled redraw pass."}),
|
"tooltip": "Overlap used in the main tiled redraw pass."}),
|
||||||
"seam_width": ("INT", {"default": 64, "min": 8, "max": 512, "step": 8,
|
"seam_width": ("INT", {"default": 64, "min": 8, "max": 512, "step": 8,
|
||||||
"tooltip": "Width of the seam bands to fix (in pixels)."}),
|
"tooltip": "Width of the seam bands to fix (in pixels)."}),
|
||||||
|
"mode": (["binary", "gradient"], {"default": "binary",
|
||||||
|
"tooltip": "binary: hard 0/1 mask. gradient: linear falloff for use with Differential Diffusion."}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,7 +43,7 @@ class GenerateSeamMask:
|
|||||||
p += stride
|
p += stride
|
||||||
return positions
|
return positions
|
||||||
|
|
||||||
def generate(self, image_width, image_height, tile_width, tile_height, overlap, seam_width):
|
def generate(self, image_width, image_height, tile_width, tile_height, overlap, seam_width, mode="binary"):
|
||||||
mask = torch.zeros(1, image_height, image_width, 3)
|
mask = torch.zeros(1, image_height, image_width, 3)
|
||||||
half_w = seam_width // 2
|
half_w = seam_width // 2
|
||||||
|
|
||||||
@@ -49,22 +51,45 @@ class GenerateSeamMask:
|
|||||||
x_tiles = self._get_tile_positions(image_width, tile_width, overlap)
|
x_tiles = self._get_tile_positions(image_width, tile_width, overlap)
|
||||||
y_tiles = self._get_tile_positions(image_height, tile_height, overlap)
|
y_tiles = self._get_tile_positions(image_height, tile_height, overlap)
|
||||||
|
|
||||||
# Vertical seam bands (between horizontally adjacent tiles)
|
if mode == "gradient":
|
||||||
for i in range(len(x_tiles) - 1):
|
# Build 1D linear ramps for each seam, then take max across all bands
|
||||||
ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0])
|
# Vertical seam bands
|
||||||
ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1])
|
for i in range(len(x_tiles) - 1):
|
||||||
center = (ovl_start + ovl_end) // 2
|
ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0])
|
||||||
x_start = max(0, center - half_w)
|
ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1])
|
||||||
x_end = min(image_width, center + half_w)
|
center = (ovl_start + ovl_end) // 2
|
||||||
mask[:, :, x_start:x_end, :] = 1.0
|
x_start = max(0, center - half_w)
|
||||||
|
x_end = min(image_width, center + half_w)
|
||||||
|
for x in range(x_start, x_end):
|
||||||
|
val = 1.0 - abs(x - center) / half_w
|
||||||
|
mask[:, :, x, :] = torch.max(mask[:, :, x, :], torch.tensor(val))
|
||||||
|
|
||||||
# Horizontal seam bands (between vertically adjacent tiles)
|
# Horizontal seam bands
|
||||||
for i in range(len(y_tiles) - 1):
|
for i in range(len(y_tiles) - 1):
|
||||||
ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0])
|
ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0])
|
||||||
ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1])
|
ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1])
|
||||||
center = (ovl_start + ovl_end) // 2
|
center = (ovl_start + ovl_end) // 2
|
||||||
y_start = max(0, center - half_w)
|
y_start = max(0, center - half_w)
|
||||||
y_end = min(image_height, center + half_w)
|
y_end = min(image_height, center + half_w)
|
||||||
mask[:, y_start:y_end, :, :] = 1.0
|
for y in range(y_start, y_end):
|
||||||
|
val = 1.0 - abs(y - center) / half_w
|
||||||
|
mask[:, y, :, :] = torch.max(mask[:, y, :, :], torch.tensor(val))
|
||||||
|
else:
|
||||||
|
# Binary mode (original behavior)
|
||||||
|
for i in range(len(x_tiles) - 1):
|
||||||
|
ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0])
|
||||||
|
ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1])
|
||||||
|
center = (ovl_start + ovl_end) // 2
|
||||||
|
x_start = max(0, center - half_w)
|
||||||
|
x_end = min(image_width, center + half_w)
|
||||||
|
mask[:, :, x_start:x_end, :] = 1.0
|
||||||
|
|
||||||
|
for i in range(len(y_tiles) - 1):
|
||||||
|
ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0])
|
||||||
|
ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1])
|
||||||
|
center = (ovl_start + ovl_end) // 2
|
||||||
|
y_start = max(0, center - half_w)
|
||||||
|
y_end = min(image_height, center + half_w)
|
||||||
|
mask[:, y_start:y_end, :, :] = 1.0
|
||||||
|
|
||||||
return (mask,)
|
return (mask,)
|
||||||
|
|||||||
Reference in New Issue
Block a user