feat: add gradient mode to GenerateSeamMask for differential diffusion

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-25 16:37:17 +01:00
parent d46192295b
commit cd00843b2e

View File

@@ -18,6 +18,8 @@ class GenerateSeamMask:
"tooltip": "Overlap used in the main tiled redraw pass."}),
"seam_width": ("INT", {"default": 64, "min": 8, "max": 512, "step": 8,
"tooltip": "Width of the seam bands to fix (in pixels)."}),
"mode": (["binary", "gradient"], {"default": "binary",
"tooltip": "binary: hard 0/1 mask. gradient: linear falloff for use with Differential Diffusion."}),
}
}
@@ -41,7 +43,7 @@ class GenerateSeamMask:
p += stride
return positions
def generate(self, image_width, image_height, tile_width, tile_height, overlap, seam_width):
def generate(self, image_width, image_height, tile_width, tile_height, overlap, seam_width, mode="binary"):
mask = torch.zeros(1, image_height, image_width, 3)
half_w = seam_width // 2
@@ -49,22 +51,45 @@ class GenerateSeamMask:
x_tiles = self._get_tile_positions(image_width, tile_width, overlap)
y_tiles = self._get_tile_positions(image_height, tile_height, overlap)
# Vertical seam bands (between horizontally adjacent tiles)
for i in range(len(x_tiles) - 1):
ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0])
ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1])
center = (ovl_start + ovl_end) // 2
x_start = max(0, center - half_w)
x_end = min(image_width, center + half_w)
mask[:, :, x_start:x_end, :] = 1.0
if mode == "gradient":
# Build 1D linear ramps for each seam, then take max across all bands
# Vertical seam bands
for i in range(len(x_tiles) - 1):
ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0])
ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1])
center = (ovl_start + ovl_end) // 2
x_start = max(0, center - half_w)
x_end = min(image_width, center + half_w)
for x in range(x_start, x_end):
val = 1.0 - abs(x - center) / half_w
mask[:, :, x, :] = torch.max(mask[:, :, x, :], torch.tensor(val))
# Horizontal seam bands (between vertically adjacent tiles)
for i in range(len(y_tiles) - 1):
ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0])
ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1])
center = (ovl_start + ovl_end) // 2
y_start = max(0, center - half_w)
y_end = min(image_height, center + half_w)
mask[:, y_start:y_end, :, :] = 1.0
# Horizontal seam bands
for i in range(len(y_tiles) - 1):
ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0])
ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1])
center = (ovl_start + ovl_end) // 2
y_start = max(0, center - half_w)
y_end = min(image_height, center + half_w)
for y in range(y_start, y_end):
val = 1.0 - abs(y - center) / half_w
mask[:, y, :, :] = torch.max(mask[:, y, :, :], torch.tensor(val))
else:
# Binary mode (original behavior)
for i in range(len(x_tiles) - 1):
ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0])
ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1])
center = (ovl_start + ovl_end) // 2
x_start = max(0, center - half_w)
x_end = min(image_width, center + half_w)
mask[:, :, x_start:x_end, :] = 1.0
for i in range(len(y_tiles) - 1):
ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0])
ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1])
center = (ovl_start + ovl_end) // 2
y_start = max(0, center - half_w)
y_end = min(image_height, center + half_w)
mask[:, y_start:y_end, :, :] = 1.0
return (mask,)