From 65253fd1bc5a2720b6c4939c498dd207a53368e7 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Wed, 25 Feb 2026 16:05:26 +0100 Subject: [PATCH] fix: compute seam positions from actual tile grid Replicate SplitImageToTileList's get_grid_coords logic to find real overlap regions between adjacent tiles. Fixes three bugs: 1. Bands were at overlap start instead of center (off by overlap/2) 2. Spurious bands generated beyond the actual tile grid 3. Edge tile seams placed at wrong position (ignoring boundary shift) Co-Authored-By: Claude Opus 4.6 --- seam_mask_node.py | 50 +++++++++++++++++++++++++------------ tests/test_seam_mask.py | 55 +++++++++++++++++++++++++++++++++++------ 2 files changed, 81 insertions(+), 24 deletions(-) diff --git a/seam_mask_node.py b/seam_mask_node.py index 43423f5..7a76a2c 100644 --- a/seam_mask_node.py +++ b/seam_mask_node.py @@ -26,27 +26,45 @@ class GenerateSeamMask: CATEGORY = "image/upscaling" DESCRIPTION = "Generates a mask image with white bands at tile seam positions. Used for targeted seam fix denoising." + @staticmethod + def _get_tile_positions(length, tile_size, overlap): + """Compute 1D tile start/end positions, matching SplitImageToTileList's get_grid_coords.""" + stride = max(1, tile_size - overlap) + positions = [] + p = 0 + while p < length: + p_end = min(p + tile_size, length) + p_start = max(0, p_end - tile_size) + positions.append((p_start, p_end)) + if p_end >= length: + break + p += stride + return positions + def generate(self, image_width, image_height, tile_width, tile_height, overlap, seam_width): mask = torch.zeros(1, image_height, image_width, 3) - - stride_x = max(1, tile_width - overlap) - stride_y = max(1, tile_height - overlap) half_w = seam_width // 2 - # Vertical seam bands - x = stride_x - while x < image_width: - x_start = max(0, x - half_w) - x_end = min(image_width, x + half_w) - mask[:, :, x_start:x_end, :] = 1.0 - x += stride_x + # Compute actual tile grids (same logic as SplitImageToTileList) + x_tiles = self._get_tile_positions(image_width, tile_width, overlap) + y_tiles = self._get_tile_positions(image_height, tile_height, overlap) - # Horizontal seam bands - y = stride_y - while y < image_height: - y_start = max(0, y - half_w) - y_end = min(image_height, y + half_w) + # Vertical seam bands (between horizontally adjacent tiles) + for i in range(len(x_tiles) - 1): + ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0]) + ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1]) + center = (ovl_start + ovl_end) // 2 + x_start = max(0, center - half_w) + x_end = min(image_width, center + half_w) + mask[:, :, x_start:x_end, :] = 1.0 + + # Horizontal seam bands (between vertically adjacent tiles) + for i in range(len(y_tiles) - 1): + ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0]) + ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1]) + center = (ovl_start + ovl_end) // 2 + y_start = max(0, center - half_w) + y_end = min(image_height, center + half_w) mask[:, y_start:y_end, :, :] = 1.0 - y += stride_y return (mask,) diff --git a/tests/test_seam_mask.py b/tests/test_seam_mask.py index 51585de..1780b5f 100644 --- a/tests/test_seam_mask.py +++ b/tests/test_seam_mask.py @@ -15,16 +15,20 @@ def test_output_shape(): def test_seam_positions(): + """Seam bands should be centered at the overlap midpoint between adjacent tiles.""" node = GenerateSeamMask() result = node.generate(image_width=2048, image_height=2048, tile_width=1024, tile_height=1024, overlap=128, seam_width=64) mask = result[0] - # Stride = 1024 - 128 = 896 - # Seams at x=896, 1792 and y=896, 1792 - assert mask[0, 0, 896, 0].item() == 1.0, "Center of vertical seam should be white" - assert mask[0, 896, 0, 0].item() == 1.0, "Center of horizontal seam should be white" + # Tiles: [0,1024), [896,1920), [1024,2048) + # Overlap between tile 0 and 1: [896, 1024), center=960 + # Band should be at [928, 992) + assert mask[0, 0, 960, 0].item() == 1.0, "Center of overlap (960) should be white" + assert mask[0, 960, 0, 0].item() == 1.0, "Horizontal seam center should be white" assert mask[0, 0, 400, 0].item() == 0.0, "Far from any seam should be black" + # Old wrong position (stride=896) should NOT be in the band + assert mask[0, 0, 896, 0].item() == 0.0, "Start of overlap (896) should be outside the band" def test_no_seams_single_tile(): @@ -37,17 +41,50 @@ def test_no_seams_single_tile(): assert mask.sum().item() == 0.0, "Single tile image should have no seams" -def test_seam_band_width(): +def test_seam_band_width_no_overlap(): + """With overlap=0, seam center is at tile boundary.""" node = GenerateSeamMask() result = node.generate(image_width=2048, image_height=1024, tile_width=1024, tile_height=1024, overlap=0, seam_width=64) mask = result[0] - # Stride = 1024, seam at x=1024, band from 992 to 1056 - assert mask[0, 0, 1023, 0].item() == 1.0, "Inside band should be white" + # Tiles: [0,1024), [1024,2048). Overlap: [1024,1024) = empty. + # Center at 1024, band [992, 1056) + assert mask[0, 0, 1024, 0].item() == 1.0, "Seam center should be white" assert mask[0, 0, 991, 0].item() == 0.0, "Outside band should be black" +def test_no_spurious_bands(): + """Should not generate bands beyond the actual tile grid.""" + node = GenerateSeamMask() + # 2816px with 1024 tiles, stride=896: 3 tiles, 2 seams + result = node.generate(image_width=2816, image_height=1024, + tile_width=1024, tile_height=1024, + overlap=128, seam_width=64) + mask = result[0] + # Tiles: [0,1024), [896,1920), [1792,2816) — 3 tiles, 2 vertical seams + # Seam 0-1: overlap [896,1024), center=960 + # Seam 1-2: overlap [1792,1920), center=1856 + assert mask[0, 0, 960, 0].item() == 1.0, "Seam 0-1 center should be white" + assert mask[0, 0, 1856, 0].item() == 1.0, "Seam 1-2 center should be white" + # x=2688 was a spurious band in the old code — should be black now + assert mask[0, 0, 2688, 0].item() == 0.0, "No spurious band beyond tile grid" + + +def test_edge_tile_seam_position(): + """Edge tile seam should be at the actual overlap center, not at n*stride.""" + node = GenerateSeamMask() + # 2048px: tiles [0,1024), [896,1920), [1024,2048) + # Edge seam between tile 1 and 2: overlap [1024,1920), center=1472 + result = node.generate(image_width=2048, image_height=1024, + tile_width=1024, tile_height=1024, + overlap=128, seam_width=64) + mask = result[0] + assert mask[0, 0, 1472, 0].item() == 1.0, "Edge tile seam center (1472) should be white" + # Old wrong position + assert mask[0, 0, 1792, 0].item() == 0.0, "Old position (1792) should be black" + + def test_values_are_binary(): node = GenerateSeamMask() result = node.generate(image_width=2048, image_height=2048, @@ -62,6 +99,8 @@ if __name__ == "__main__": test_output_shape() test_seam_positions() test_no_seams_single_tile() - test_seam_band_width() + test_seam_band_width_no_overlap() + test_no_spurious_bands() + test_edge_tile_seam_position() test_values_are_binary() print("All tests passed!")