fix: compute seam positions from actual tile grid
Replicate SplitImageToTileList's get_grid_coords logic to find real overlap regions between adjacent tiles. Fixes three bugs: 1. Bands were at overlap start instead of center (off by overlap/2) 2. Spurious bands generated beyond the actual tile grid 3. Edge tile seams placed at wrong position (ignoring boundary shift) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -26,27 +26,45 @@ class GenerateSeamMask:
|
|||||||
CATEGORY = "image/upscaling"
|
CATEGORY = "image/upscaling"
|
||||||
DESCRIPTION = "Generates a mask image with white bands at tile seam positions. Used for targeted seam fix denoising."
|
DESCRIPTION = "Generates a mask image with white bands at tile seam positions. Used for targeted seam fix denoising."
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_tile_positions(length, tile_size, overlap):
|
||||||
|
"""Compute 1D tile start/end positions, matching SplitImageToTileList's get_grid_coords."""
|
||||||
|
stride = max(1, tile_size - overlap)
|
||||||
|
positions = []
|
||||||
|
p = 0
|
||||||
|
while p < length:
|
||||||
|
p_end = min(p + tile_size, length)
|
||||||
|
p_start = max(0, p_end - tile_size)
|
||||||
|
positions.append((p_start, p_end))
|
||||||
|
if p_end >= length:
|
||||||
|
break
|
||||||
|
p += stride
|
||||||
|
return positions
|
||||||
|
|
||||||
def generate(self, image_width, image_height, tile_width, tile_height, overlap, seam_width):
|
def generate(self, image_width, image_height, tile_width, tile_height, overlap, seam_width):
|
||||||
mask = torch.zeros(1, image_height, image_width, 3)
|
mask = torch.zeros(1, image_height, image_width, 3)
|
||||||
|
|
||||||
stride_x = max(1, tile_width - overlap)
|
|
||||||
stride_y = max(1, tile_height - overlap)
|
|
||||||
half_w = seam_width // 2
|
half_w = seam_width // 2
|
||||||
|
|
||||||
# Vertical seam bands
|
# Compute actual tile grids (same logic as SplitImageToTileList)
|
||||||
x = stride_x
|
x_tiles = self._get_tile_positions(image_width, tile_width, overlap)
|
||||||
while x < image_width:
|
y_tiles = self._get_tile_positions(image_height, tile_height, overlap)
|
||||||
x_start = max(0, x - half_w)
|
|
||||||
x_end = min(image_width, x + half_w)
|
|
||||||
mask[:, :, x_start:x_end, :] = 1.0
|
|
||||||
x += stride_x
|
|
||||||
|
|
||||||
# Horizontal seam bands
|
# Vertical seam bands (between horizontally adjacent tiles)
|
||||||
y = stride_y
|
for i in range(len(x_tiles) - 1):
|
||||||
while y < image_height:
|
ovl_start = max(x_tiles[i][0], x_tiles[i + 1][0])
|
||||||
y_start = max(0, y - half_w)
|
ovl_end = min(x_tiles[i][1], x_tiles[i + 1][1])
|
||||||
y_end = min(image_height, y + half_w)
|
center = (ovl_start + ovl_end) // 2
|
||||||
|
x_start = max(0, center - half_w)
|
||||||
|
x_end = min(image_width, center + half_w)
|
||||||
|
mask[:, :, x_start:x_end, :] = 1.0
|
||||||
|
|
||||||
|
# Horizontal seam bands (between vertically adjacent tiles)
|
||||||
|
for i in range(len(y_tiles) - 1):
|
||||||
|
ovl_start = max(y_tiles[i][0], y_tiles[i + 1][0])
|
||||||
|
ovl_end = min(y_tiles[i][1], y_tiles[i + 1][1])
|
||||||
|
center = (ovl_start + ovl_end) // 2
|
||||||
|
y_start = max(0, center - half_w)
|
||||||
|
y_end = min(image_height, center + half_w)
|
||||||
mask[:, y_start:y_end, :, :] = 1.0
|
mask[:, y_start:y_end, :, :] = 1.0
|
||||||
y += stride_y
|
|
||||||
|
|
||||||
return (mask,)
|
return (mask,)
|
||||||
|
|||||||
@@ -15,16 +15,20 @@ def test_output_shape():
|
|||||||
|
|
||||||
|
|
||||||
def test_seam_positions():
|
def test_seam_positions():
|
||||||
|
"""Seam bands should be centered at the overlap midpoint between adjacent tiles."""
|
||||||
node = GenerateSeamMask()
|
node = GenerateSeamMask()
|
||||||
result = node.generate(image_width=2048, image_height=2048,
|
result = node.generate(image_width=2048, image_height=2048,
|
||||||
tile_width=1024, tile_height=1024,
|
tile_width=1024, tile_height=1024,
|
||||||
overlap=128, seam_width=64)
|
overlap=128, seam_width=64)
|
||||||
mask = result[0]
|
mask = result[0]
|
||||||
# Stride = 1024 - 128 = 896
|
# Tiles: [0,1024), [896,1920), [1024,2048)
|
||||||
# Seams at x=896, 1792 and y=896, 1792
|
# Overlap between tile 0 and 1: [896, 1024), center=960
|
||||||
assert mask[0, 0, 896, 0].item() == 1.0, "Center of vertical seam should be white"
|
# Band should be at [928, 992)
|
||||||
assert mask[0, 896, 0, 0].item() == 1.0, "Center of horizontal seam should be white"
|
assert mask[0, 0, 960, 0].item() == 1.0, "Center of overlap (960) should be white"
|
||||||
|
assert mask[0, 960, 0, 0].item() == 1.0, "Horizontal seam center should be white"
|
||||||
assert mask[0, 0, 400, 0].item() == 0.0, "Far from any seam should be black"
|
assert mask[0, 0, 400, 0].item() == 0.0, "Far from any seam should be black"
|
||||||
|
# Old wrong position (stride=896) should NOT be in the band
|
||||||
|
assert mask[0, 0, 896, 0].item() == 0.0, "Start of overlap (896) should be outside the band"
|
||||||
|
|
||||||
|
|
||||||
def test_no_seams_single_tile():
|
def test_no_seams_single_tile():
|
||||||
@@ -37,17 +41,50 @@ def test_no_seams_single_tile():
|
|||||||
assert mask.sum().item() == 0.0, "Single tile image should have no seams"
|
assert mask.sum().item() == 0.0, "Single tile image should have no seams"
|
||||||
|
|
||||||
|
|
||||||
def test_seam_band_width():
|
def test_seam_band_width_no_overlap():
|
||||||
|
"""With overlap=0, seam center is at tile boundary."""
|
||||||
node = GenerateSeamMask()
|
node = GenerateSeamMask()
|
||||||
result = node.generate(image_width=2048, image_height=1024,
|
result = node.generate(image_width=2048, image_height=1024,
|
||||||
tile_width=1024, tile_height=1024,
|
tile_width=1024, tile_height=1024,
|
||||||
overlap=0, seam_width=64)
|
overlap=0, seam_width=64)
|
||||||
mask = result[0]
|
mask = result[0]
|
||||||
# Stride = 1024, seam at x=1024, band from 992 to 1056
|
# Tiles: [0,1024), [1024,2048). Overlap: [1024,1024) = empty.
|
||||||
assert mask[0, 0, 1023, 0].item() == 1.0, "Inside band should be white"
|
# Center at 1024, band [992, 1056)
|
||||||
|
assert mask[0, 0, 1024, 0].item() == 1.0, "Seam center should be white"
|
||||||
assert mask[0, 0, 991, 0].item() == 0.0, "Outside band should be black"
|
assert mask[0, 0, 991, 0].item() == 0.0, "Outside band should be black"
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_spurious_bands():
|
||||||
|
"""Should not generate bands beyond the actual tile grid."""
|
||||||
|
node = GenerateSeamMask()
|
||||||
|
# 2816px with 1024 tiles, stride=896: 3 tiles, 2 seams
|
||||||
|
result = node.generate(image_width=2816, image_height=1024,
|
||||||
|
tile_width=1024, tile_height=1024,
|
||||||
|
overlap=128, seam_width=64)
|
||||||
|
mask = result[0]
|
||||||
|
# Tiles: [0,1024), [896,1920), [1792,2816) — 3 tiles, 2 vertical seams
|
||||||
|
# Seam 0-1: overlap [896,1024), center=960
|
||||||
|
# Seam 1-2: overlap [1792,1920), center=1856
|
||||||
|
assert mask[0, 0, 960, 0].item() == 1.0, "Seam 0-1 center should be white"
|
||||||
|
assert mask[0, 0, 1856, 0].item() == 1.0, "Seam 1-2 center should be white"
|
||||||
|
# x=2688 was a spurious band in the old code — should be black now
|
||||||
|
assert mask[0, 0, 2688, 0].item() == 0.0, "No spurious band beyond tile grid"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edge_tile_seam_position():
|
||||||
|
"""Edge tile seam should be at the actual overlap center, not at n*stride."""
|
||||||
|
node = GenerateSeamMask()
|
||||||
|
# 2048px: tiles [0,1024), [896,1920), [1024,2048)
|
||||||
|
# Edge seam between tile 1 and 2: overlap [1024,1920), center=1472
|
||||||
|
result = node.generate(image_width=2048, image_height=1024,
|
||||||
|
tile_width=1024, tile_height=1024,
|
||||||
|
overlap=128, seam_width=64)
|
||||||
|
mask = result[0]
|
||||||
|
assert mask[0, 0, 1472, 0].item() == 1.0, "Edge tile seam center (1472) should be white"
|
||||||
|
# Old wrong position
|
||||||
|
assert mask[0, 0, 1792, 0].item() == 0.0, "Old position (1792) should be black"
|
||||||
|
|
||||||
|
|
||||||
def test_values_are_binary():
|
def test_values_are_binary():
|
||||||
node = GenerateSeamMask()
|
node = GenerateSeamMask()
|
||||||
result = node.generate(image_width=2048, image_height=2048,
|
result = node.generate(image_width=2048, image_height=2048,
|
||||||
@@ -62,6 +99,8 @@ if __name__ == "__main__":
|
|||||||
test_output_shape()
|
test_output_shape()
|
||||||
test_seam_positions()
|
test_seam_positions()
|
||||||
test_no_seams_single_tile()
|
test_no_seams_single_tile()
|
||||||
test_seam_band_width()
|
test_seam_band_width_no_overlap()
|
||||||
|
test_no_spurious_bands()
|
||||||
|
test_edge_tile_seam_position()
|
||||||
test_values_are_binary()
|
test_values_are_binary()
|
||||||
print("All tests passed!")
|
print("All tests passed!")
|
||||||
|
|||||||
Reference in New Issue
Block a user