Add VACE Merge Back node for splicing VACE output into original video

Adds a new node that reconstructs full-length video by splicing VACE
sampler output back into the original clip at the trim positions. Supports
optical flow, alpha, and hard-cut blending at context/generated seams.
Also adds trim_start/trim_end INT outputs to VACESourcePrep.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-19 22:43:07 +01:00
parent 6fa235f26c
commit 89fa3405cb
5 changed files with 336 additions and 13 deletions

186
merge_node.py Normal file
View File

@@ -0,0 +1,186 @@
import torch
import numpy as np
OPTICAL_FLOW_PRESETS = {
'fast': {'levels': 2, 'winsize': 11, 'iterations': 2, 'poly_n': 5, 'poly_sigma': 1.1},
'balanced': {'levels': 3, 'winsize': 15, 'iterations': 3, 'poly_n': 5, 'poly_sigma': 1.2},
'quality': {'levels': 5, 'winsize': 21, 'iterations': 5, 'poly_n': 7, 'poly_sigma': 1.5},
'max': {'levels': 7, 'winsize': 31, 'iterations': 10, 'poly_n': 7, 'poly_sigma': 1.5},
}
PASS_THROUGH_MODES = {"Edge Extend", "Frame Interpolation", "Keyframe", "Video Inpaint"}
def _count_leading_black(mask):
"""Count consecutive black (context) frames at the start of mask."""
count = 0
for i in range(mask.shape[0]):
if mask[i].max().item() < 0.01:
count += 1
else:
break
return count
def _count_trailing_black(mask):
"""Count consecutive black (context) frames at the end of mask."""
count = 0
for i in range(mask.shape[0] - 1, -1, -1):
if mask[i].max().item() < 0.01:
count += 1
else:
break
return count
def _alpha_blend(frame_a, frame_b, alpha):
"""Simple linear crossfade between two frames (H,W,3 tensors)."""
return frame_a * (1.0 - alpha) + frame_b * alpha
def _optical_flow_blend(frame_a, frame_b, alpha, preset):
"""Motion-compensated blend using Farneback optical flow."""
try:
import cv2
except ImportError:
return _alpha_blend(frame_a, frame_b, alpha)
params = OPTICAL_FLOW_PRESETS[preset]
arr_a = (frame_a.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
arr_b = (frame_b.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
gray_a = cv2.cvtColor(arr_a, cv2.COLOR_RGB2GRAY)
gray_b = cv2.cvtColor(arr_b, cv2.COLOR_RGB2GRAY)
flow = cv2.calcOpticalFlowFarneback(
gray_a, gray_b, None,
pyr_scale=0.5,
levels=params['levels'],
winsize=params['winsize'],
iterations=params['iterations'],
poly_n=params['poly_n'],
poly_sigma=params['poly_sigma'],
flags=0,
)
h, w = flow.shape[:2]
x_coords = np.tile(np.arange(w), (h, 1)).astype(np.float32)
y_coords = np.tile(np.arange(h), (w, 1)).T.astype(np.float32)
# Warp A forward by alpha * flow
flow_fwd = flow * alpha
warped_a = cv2.remap(
arr_a,
x_coords + flow_fwd[..., 0],
y_coords + flow_fwd[..., 1],
cv2.INTER_LINEAR,
borderMode=cv2.BORDER_REPLICATE,
)
# Warp B backward by -(1-alpha) * flow
flow_back = -flow * (1 - alpha)
warped_b = cv2.remap(
arr_b,
x_coords + flow_back[..., 0],
y_coords + flow_back[..., 1],
cv2.INTER_LINEAR,
borderMode=cv2.BORDER_REPLICATE,
)
result = cv2.addWeighted(warped_a, 1 - alpha, warped_b, alpha, 0)
return torch.from_numpy(result.astype(np.float32) / 255.0).to(frame_a.device)
class VACEMergeBack:
CATEGORY = "VACE Tools"
FUNCTION = "merge"
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("merged_clip",)
OUTPUT_TOOLTIPS = (
"Full reconstructed video with VACE output spliced back into the original clip.",
)
DESCRIPTION = """VACE Merge Back — splices VACE sampler output back into the original full-length video.
Connect the original (untrimmed) clip, the VACE sampler output, the mask from VACE Mask Generator,
and the mode/trim_start/trim_end from VACE Source Prep. The node detects context zones from the mask
and blends at the seams where context meets generated frames.
Pass-through modes (Edge Extend, Frame Interpolation, Keyframe, Video Inpaint):
Returns vace_output as-is — the VACE output IS the final result.
Splice modes (End, Pre, Middle, Join, Bidirectional, Replace):
Reconstructs original[:trim_start] + vace_output + original[trim_end:]
with optional blending at the seams.
Blend methods:
none — Hard cut at seams (fastest)
alpha — Simple linear crossfade
optical_flow — Motion-compensated blend using Farneback dense optical flow"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_clip": ("IMAGE", {"description": "Full original video (before any trimming)."}),
"vace_output": ("IMAGE", {"description": "VACE sampler output."}),
"mask": ("IMAGE", {"description": "Mask from VACE Mask Generator — BLACK=context, WHITE=generated."}),
"mode": ("STRING", {"forceInput": True, "description": "Mode from VACE Source Prep."}),
"trim_start": ("INT", {"forceInput": True, "default": 0, "description": "Start of trimmed region in original."}),
"trim_end": ("INT", {"forceInput": True, "default": 0, "description": "End of trimmed region in original."}),
"blend_frames": ("INT", {"default": 4, "min": 0, "max": 100, "description": "Context frames to blend at each seam (0 = hard cut)."}),
"blend_method": (["optical_flow", "alpha", "none"], {"default": "optical_flow", "description": "Blending method at seams."}),
"of_preset": (["fast", "balanced", "quality", "max"], {"default": "balanced", "description": "Optical flow quality preset."}),
},
}
def merge(self, original_clip, vace_output, mask, mode, trim_start, trim_end, blend_frames, blend_method, of_preset):
# Pass-through modes: VACE output IS the final result
if mode in PASS_THROUGH_MODES:
return (vace_output,)
# Splice modes: reconstruct full video
V = vace_output.shape[0]
head = original_clip[:trim_start]
tail = original_clip[trim_end:]
result = torch.cat([head, vace_output, tail], dim=0)
if blend_method == "none" or blend_frames <= 0:
return (result,)
# Detect context zones from mask
left_ctx_len = _count_leading_black(mask)
right_ctx_len = _count_trailing_black(mask)
def blend_frame(orig, vace, alpha):
if blend_method == "optical_flow":
return _optical_flow_blend(orig, vace, alpha, of_preset)
return _alpha_blend(orig, vace, alpha)
# Blend at LEFT seam (context → generated transition)
bf_left = min(blend_frames, left_ctx_len)
for j in range(bf_left):
alpha = (j + 1) / (bf_left + 1)
orig_frame = original_clip[trim_start + j]
vace_frame = vace_output[j]
result[trim_start + j] = blend_frame(orig_frame, vace_frame, alpha)
# Blend at RIGHT seam (generated → context transition)
bf_right = min(blend_frames, right_ctx_len)
for j in range(bf_right):
alpha = 1.0 - (j + 1) / (bf_right + 1)
frame_idx = V - bf_right + j
orig_frame = original_clip[trim_end - bf_right + j]
vace_frame = vace_output[frame_idx]
result[trim_start + frame_idx] = blend_frame(orig_frame, vace_frame, alpha)
return (result,)
NODE_CLASS_MAPPINGS = {
"VACEMergeBack": VACEMergeBack,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"VACEMergeBack": "VACE Merge Back",
}