Add Bidirectional Extend, Frame Interpolation, and Replace/Inpaint modes

Three new modes for the VACE Mask Generator node, bringing the total to 8.
Bidirectional generates before and after the clip, Frame Interpolation
inserts frames between each source pair, and Replace/Inpaint regenerates
a region in-place. All reuse existing inputs with mode-specific semantics.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-19 13:23:12 +01:00
parent c44e9dd62a
commit d8143ea889
2 changed files with 111 additions and 8 deletions

View File

@@ -55,10 +55,13 @@ class VACEMaskGenerator:
"Middle Extend",
"Edge Extend",
"Join Extend",
"Bidirectional Extend",
"Frame Interpolation",
"Replace/Inpaint",
],
{
"default": "End Extend",
"description": "End: generate after clip. Pre: generate before clip. Middle: generate at split point. Edge: generate between reversed edges (looping). Join: generate to heal two halves.",
"description": "End: generate after clip. Pre: generate before clip. Middle: generate at split point. Edge: generate between reversed edges (looping). Join: generate to heal two halves. Bidirectional: generate before AND after clip. Frame Interpolation: insert generated frames between each source pair. Replace/Inpaint: regenerate a range of frames in-place.",
},
),
"target_frames": (
@@ -67,7 +70,7 @@ class VACEMaskGenerator:
"default": 81,
"min": 1,
"max": 10000,
"description": "Total output frame count for mask and control_frames.",
"description": "Total output frame count for mask and control_frames. Unused by Frame Interpolation and Replace/Inpaint.",
},
),
"split_index": (
@@ -76,7 +79,7 @@ class VACEMaskGenerator:
"default": 0,
"min": -10000,
"max": 10000,
"description": "Where to split the source. End: trim from end (e.g. -16). Pre: reference frames from start (e.g. 24). Middle: split frame index. Unused by Edge/Join.",
"description": "Where to split the source. End: trim from end (e.g. -16). Pre: reference frames from start (e.g. 24). Middle: split frame index. Unused by Edge/Join. Bidirectional: frames before clip (0 = even split). Frame Interpolation: new frames per gap. Replace/Inpaint: start index of replace region.",
},
),
"edge_frames": (
@@ -85,7 +88,7 @@ class VACEMaskGenerator:
"default": 8,
"min": 1,
"max": 10000,
"description": "Number of edge frames to use for Edge and Join modes. Unused by End/Pre/Middle.",
"description": "Number of edge frames to use for Edge and Join modes. Unused by End/Pre/Middle. Replace/Inpaint: number of frames to replace.",
},
),
}
@@ -159,6 +162,43 @@ class VACEMaskGenerator:
control_frames = torch.cat([part_2, solid(frames_to_generate, GREY), part_3], dim=0)
return (mask, control_frames, safe(part_1), safe(part_2), safe(part_3), safe(part_4), frames_to_generate)
elif mode == "Bidirectional Extend":
frames_to_generate = max(0, target_frames - B)
if split_index > 0:
pre_count = min(split_index, frames_to_generate)
else:
pre_count = frames_to_generate // 2
post_count = frames_to_generate - pre_count
mask = torch.cat([solid(pre_count, WHITE), solid(B, BLACK), solid(post_count, WHITE)], dim=0)
control_frames = torch.cat([solid(pre_count, GREY), source_clip, solid(post_count, GREY)], dim=0)
return (mask, control_frames, source_clip, ph(), ph(), ph(), frames_to_generate)
elif mode == "Frame Interpolation":
step = max(split_index, 1)
frames_to_generate = (B - 1) * step
mask_parts = []
ctrl_parts = []
for i in range(B):
mask_parts.append(solid(1, BLACK))
ctrl_parts.append(source_clip[i:i+1])
if i < B - 1:
mask_parts.append(solid(step, WHITE))
ctrl_parts.append(solid(step, GREY))
mask = torch.cat(mask_parts, dim=0)
control_frames = torch.cat(ctrl_parts, dim=0)
return (mask, control_frames, source_clip, ph(), ph(), ph(), frames_to_generate)
elif mode == "Replace/Inpaint":
start = max(0, min(split_index, B))
length = max(0, min(edge_frames, B - start))
end = start + length
frames_to_generate = length
before = source_clip[:start]
after = source_clip[end:]
mask = torch.cat([solid(before.shape[0], BLACK), solid(length, WHITE), solid(after.shape[0], BLACK)], dim=0)
control_frames = torch.cat([before, solid(length, GREY), after], dim=0)
return (mask, control_frames, safe(before), safe(source_clip[start:end]), safe(after), ph(), frames_to_generate)
raise ValueError(f"Unknown mode: {mode}")