diff --git a/seam_mask_node.py b/seam_mask_node.py index 7a76a2c..ffb9f76 100644 --- a/seam_mask_node.py +++ b/seam_mask_node.py @@ -18,6 +18,8 @@ class GenerateSeamMask: "tooltip": "Overlap used in the main tiled redraw pass."}), "seam_width": ("INT", {"default": 64, "min": 8, "max": 512, "step": 8, "tooltip": "Width of the seam bands to fix (in pixels)."}), + "mode": (["binary", "gradient"], {"default": "binary", + "tooltip": "binary: hard 0/1 mask. gradient: linear falloff for use with Differential Diffusion."}), } } @@ -41,7 +43,7 @@ class GenerateSeamMask: p += stride return positions - def generate(self, image_width, image_height, tile_width, tile_height, overlap, seam_width): + def generate(self, image_width, image_height, tile_width, tile_height, overlap, seam_width, mode="binary"): mask = torch.zeros(1, image_height, image_width, 3) half_w = seam_width // 2 @@ -49,22 +51,45 @@ class GenerateSeamMask: x_tiles = self._get_tile_positions(image_width, tile_width, overlap) y_tiles = self._get_tile_positions(image_height, tile_height, overlap) - # Vertical seam bands (between horizontally adjacent tiles) - for i in range(len(x_tiles) - 1): - ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0]) - ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1]) - center = (ovl_start + ovl_end) // 2 - x_start = max(0, center - half_w) - x_end = min(image_width, center + half_w) - mask[:, :, x_start:x_end, :] = 1.0 + if mode == "gradient": + # Build 1D linear ramps for each seam, then take max across all bands + # Vertical seam bands + for i in range(len(x_tiles) - 1): + ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0]) + ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1]) + center = (ovl_start + ovl_end) // 2 + x_start = max(0, center - half_w) + x_end = min(image_width, center + half_w) + for x in range(x_start, x_end): + val = 1.0 - abs(x - center) / half_w + mask[:, :, x, :] = torch.max(mask[:, :, x, :], torch.tensor(val)) - # Horizontal seam bands (between vertically adjacent tiles) - for i in range(len(y_tiles) - 1): - ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0]) - ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1]) - center = (ovl_start + ovl_end) // 2 - y_start = max(0, center - half_w) - y_end = min(image_height, center + half_w) - mask[:, y_start:y_end, :, :] = 1.0 + # Horizontal seam bands + for i in range(len(y_tiles) - 1): + ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0]) + ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1]) + center = (ovl_start + ovl_end) // 2 + y_start = max(0, center - half_w) + y_end = min(image_height, center + half_w) + for y in range(y_start, y_end): + val = 1.0 - abs(y - center) / half_w + mask[:, y, :, :] = torch.max(mask[:, y, :, :], torch.tensor(val)) + else: + # Binary mode (original behavior) + for i in range(len(x_tiles) - 1): + ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0]) + ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1]) + center = (ovl_start + ovl_end) // 2 + x_start = max(0, center - half_w) + x_end = min(image_width, center + half_w) + mask[:, :, x_start:x_end, :] = 1.0 + + for i in range(len(y_tiles) - 1): + ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0]) + ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1]) + center = (ovl_start + ovl_end) // 2 + y_start = max(0, center - half_w) + y_end = min(image_height, center + half_w) + mask[:, y_start:y_end, :, :] = 1.0 return (mask,)