Add VACE Merge Back node for splicing VACE output into original video

Adds a new node that reconstructs full-length video by splicing VACE
sampler output back into the original clip at the trim positions. Supports
optical flow, alpha, and hard-cut blending at context/generated seams.
Also adds trim_start/trim_end INT outputs to VACESourcePrep.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-19 22:43:07 +01:00
parent 6fa235f26c
commit 89fa3405cb
5 changed files with 336 additions and 13 deletions

View File

@@ -330,11 +330,11 @@ If your source is longer, use VACE Source Prep upstream to trim it first."""
class VACESourcePrep:
CATEGORY = "VACE Tools"
FUNCTION = "prepare"
RETURN_TYPES = ("IMAGE", "STRING", "INT", "INT", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "MASK", "STRING")
RETURN_TYPES = ("IMAGE", "STRING", "INT", "INT", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "MASK", "STRING", "INT", "INT")
RETURN_NAMES = (
"source_clip", "mode", "split_index", "edge_frames",
"segment_1", "segment_2", "segment_3", "segment_4",
"inpaint_mask", "keyframe_positions",
"inpaint_mask", "keyframe_positions", "trim_start", "trim_end",
)
OUTPUT_TOOLTIPS = (
"Trimmed source frames — wire to VACE Mask Generator's source_clip.",
@@ -347,6 +347,8 @@ class VACESourcePrep:
"Segment 4: Join: part 4. Others: placeholder.",
"Inpaint mask trimmed to match output — wire to VACE Mask Generator.",
"Keyframe positions pass-through — wire to VACE Mask Generator.",
"Start index of the trimmed region in the original clip — wire to VACE Merge Back.",
"End index of the trimmed region in the original clip — wire to VACE Merge Back.",
)
DESCRIPTION = """VACE Source Prep — trims long source clips for VACE Mask Generator.
@@ -483,7 +485,7 @@ input_left / input_right (0 = use all available):
else:
output = source_clip
start = 0
return (output, mode, 0, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out)
return (output, mode, 0, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out, start, B)
elif mode == "Pre Extend":
if input_right > 0:
@@ -492,7 +494,7 @@ input_left / input_right (0 = use all available):
else:
output = source_clip
end = B
return (output, mode, output.shape[0], edge_frames, safe(output), ph(), ph(), ph(), trim_mask(0, end), kp_out)
return (output, mode, output.shape[0], edge_frames, safe(output), ph(), ph(), ph(), trim_mask(0, end), kp_out, 0, end)
elif mode == "Middle Extend":
left_start = max(0, split_index - input_left) if input_left > 0 else 0
@@ -501,7 +503,7 @@ input_left / input_right (0 = use all available):
out_split = split_index - left_start
part_a = source_clip[left_start:split_index]
part_b = source_clip[split_index:right_end]
return (output, mode, out_split, edge_frames, safe(part_a), safe(part_b), ph(), ph(), trim_mask(left_start, right_end), kp_out)
return (output, mode, out_split, edge_frames, safe(part_a), safe(part_b), ph(), ph(), trim_mask(left_start, right_end), kp_out, left_start, right_end)
elif mode == "Edge Extend":
eff_left = min(input_left if input_left > 0 else edge_frames, B)
@@ -511,7 +513,7 @@ input_left / input_right (0 = use all available):
end_seg = source_clip[-sym:] if sym > 0 else source_clip[:0]
mid_seg = source_clip[sym:B - sym] if 2 * sym < B else source_clip[:0]
output = torch.cat([start_seg, end_seg], dim=0)
return (output, mode, 0, sym, safe(start_seg), safe(mid_seg), safe(end_seg), ph(), mask_ph(), kp_out)
return (output, mode, 0, sym, safe(start_seg), safe(mid_seg), safe(end_seg), ph(), mask_ph(), kp_out, 0, B)
elif mode == "Join Extend":
half = B // 2
@@ -527,7 +529,7 @@ input_left / input_right (0 = use all available):
part_3 = second_half[:sym]
part_4 = second_half[sym:]
output = torch.cat([part_2, part_3], dim=0)
return (output, mode, 0, sym, safe(part_1), safe(part_2), safe(part_3), safe(part_4), mask_ph(), kp_out)
return (output, mode, 0, sym, safe(part_1), safe(part_2), safe(part_3), safe(part_4), mask_ph(), kp_out, half - sym, half + sym)
elif mode == "Bidirectional Extend":
if input_left > 0:
@@ -536,10 +538,10 @@ input_left / input_right (0 = use all available):
else:
output = source_clip
start = 0
return (output, mode, split_index, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out)
return (output, mode, split_index, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out, start, B)
elif mode == "Frame Interpolation":
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), trim_mask(0, B), kp_out)
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), trim_mask(0, B), kp_out, 0, B)
elif mode == "Replace/Inpaint":
start = max(0, min(split_index, B))
@@ -553,14 +555,14 @@ input_left / input_right (0 = use all available):
output = torch.cat([before, replace_region, after], dim=0)
out_split = before.shape[0]
out_edge = length
return (output, mode, out_split, out_edge, safe(before), safe(replace_region), safe(after), ph(), trim_mask(ctx_start, ctx_end), kp_out)
return (output, mode, out_split, out_edge, safe(before), safe(replace_region), safe(after), ph(), trim_mask(ctx_start, ctx_end), kp_out, ctx_start, ctx_end)
elif mode == "Video Inpaint":
out_mask = inpaint_mask.to(dev) if inpaint_mask is not None else mask_ph()
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), out_mask, kp_out)
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), out_mask, kp_out, 0, B)
elif mode == "Keyframe":
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), mask_ph(), kp_out)
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), mask_ph(), kp_out, 0, B)
raise ValueError(f"Unknown mode: {mode}")