Add VACE Merge Back node for splicing VACE output into original video

Adds a new node that reconstructs full-length video by splicing VACE
sampler output back into the original clip at the trim positions. Supports
optical flow, alpha, and hard-cut blending at context/generated seams.
Also adds trim_start/trim_end INT outputs to VACESourcePrep.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 22:43:07 +01:00
parent 6fa235f26c
commit 89fa3405cb
5 changed files with 336 additions and 13 deletions


@@ -41,6 +41,8 @@ Irrelevant widgets are automatically hidden based on the selected mode.
| `segment_1`–`segment_4` | IMAGE | Frame segments per mode (same meaning as mask generator segments). Unused segments are 1-frame black placeholders. |
| `inpaint_mask` | MASK | Trimmed to match output, or placeholder. |
| `keyframe_positions` | STRING | Pass-through. |
| `trim_start` | INT | Start index of the trimmed region in the original clip — wire to VACE Merge Back. |
| `trim_end` | INT | End index of the trimmed region in the original clip — wire to VACE Merge Back. |
### Per-Mode Trimming
@@ -314,6 +316,72 @@ control_frames: [ k0][ GREY ][ k1][ GREY ][ k2][ GREY ][ k3]
---
## Node: VACE Merge Back
Splices VACE sampler output back into the original full-length video. Connect the original (untrimmed) clip, the VACE sampler output, the mask from VACE Mask Generator, and the `mode`/`trim_start`/`trim_end` from VACE Source Prep.
Irrelevant widgets are automatically hidden based on the selected blend method.
### Inputs
| Input | Type | Default | Description |
|---|---|---|---|
| `original_clip` | IMAGE | — | Full original video (before any trimming). |
| `vace_output` | IMAGE | — | VACE sampler output. |
| `mask` | IMAGE | — | Mask from VACE Mask Generator — BLACK=context, WHITE=generated. |
| `mode` | STRING | *(wired)* | Mode from VACE Source Prep (must be wired, not typed). |
| `trim_start` | INT | *(wired)* | Start of trimmed region in original (from VACE Source Prep). |
| `trim_end` | INT | *(wired)* | End of trimmed region in original (from VACE Source Prep). |
| `blend_frames` | INT | `4` | Context frames to blend at each seam (0 = hard cut). |
| `blend_method` | ENUM | `optical_flow` | `none` (hard cut), `alpha` (linear crossfade), or `optical_flow` (motion-compensated). |
| `of_preset` | ENUM | `balanced` | Optical flow quality: `fast`, `balanced`, `quality`, `max`. |
### Outputs
| Output | Type | Description |
|---|---|---|
| `merged_clip` | IMAGE | Full reconstructed video. |
### Behavior
**Pass-through modes** (Edge Extend, Frame Interpolation, Keyframe, Video Inpaint): returns `vace_output` as-is — the VACE output IS the final result for these modes.
**Splice modes** (End, Pre, Middle, Join, Bidirectional, Replace): reconstructs `original[:trim_start] + vace_output + original[trim_end:]`, then blends at the seams where context frames meet original frames.
The node detects context zones by counting consecutive black frames at the start and end of the mask. At each seam, up to `blend_frames` frames (clamped to the context zone length) are blended with a linear alpha ramp. Optical flow blending warps both frames along the motion field before crossfading, reducing ghosting on moving subjects.
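The ramp can be sketched in a few lines of Python (a hypothetical `seam_alphas` helper for illustration, not part of the node):

```python
def seam_alphas(blend_frames: int) -> list[float]:
    """Left-seam blend weights: 0 = keep original, 1 = keep generated.

    Mirrors the (j + 1) / (blend_frames + 1) ramp used at each seam;
    the right seam uses the same values reversed.
    """
    return [(j + 1) / (blend_frames + 1) for j in range(blend_frames)]

# With the default blend_frames=4 the weights never hit 0 or 1,
# so no blended frame is an abrupt copy of either side.
print(seam_alphas(4))  # [0.2, 0.4, 0.6, 0.8]
```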
### Example: Middle Extend
```
Original:  274 frames (0–273)
Prep: split_index=137, input_left=16, input_right=16
→ trim_start=121, trim_end=153, trimmed=32 frames
Mask Gen: target_frames=81
→ mask = [BLACK×16] [WHITE×49] [BLACK×16]
VACE out: 81 frames (from sampler)
Merge: result = original[0:121] + vace[0:81] + original[153:274]
→ 121 + 81 + 121 = 323 frames
Left blend: vace[0..3] ↔ original[121..124]
Right blend: vace[77..80] ↔ original[149..152]
```
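The frame arithmetic in this example can be checked with a short sketch (a hypothetical `merged_length` helper, not part of the node):

```python
def merged_length(original_frames: int, trim_start: int,
                  trim_end: int, vace_frames: int) -> int:
    # head (original[:trim_start]) + generated + tail (original[trim_end:])
    return trim_start + vace_frames + (original_frames - trim_end)

# Middle Extend example above: 274-frame source, 81-frame VACE output
print(merged_length(274, 121, 153, 81))  # 323
```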
### Wiring Diagram
```
[Load Video]
├─ source_clip ──→ [VACESourcePrep] ─┬─ source_clip ──→ [MaskGen] ─→ mask ──┐
│ ├─ mode ───────────────────────────────┤
│ ├─ trim_start ─────────────────────────┤
│ └─ trim_end ──────────────────────────┤
│ │
└─ original_clip ───────────────────────────────────────────────────────────→ [VACEMergeBack]
[Sampler] ─→ vace_output ────────────────┘
```
---
## Node: WanVideo Save Merged Model
Saves a WanVideo diffusion model (with merged LoRAs) as a `.safetensors` file. Found under the **WanVideoWrapper** category.
@@ -379,4 +447,5 @@ Loads a LATENT from an absolute file path. Found under the **latent** category.
## Dependencies
PyTorch and safetensors, both bundled with ComfyUI.
- **PyTorch** and **safetensors** bundled with ComfyUI.
- **OpenCV** (`cv2`) — optional, for optical flow blending in VACE Merge Back. Falls back to alpha blending if unavailable.
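The optional-OpenCV behavior follows the usual guarded-import pattern; a minimal sketch of the availability check (not the node's exact code, which performs the import inside the blend function):

```python
def can_use_optical_flow() -> bool:
    """True when cv2 is importable; otherwise fall back to alpha blending."""
    try:
        import cv2  # noqa: F401
        return True
    except ImportError:
        return False
```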


@@ -7,11 +7,17 @@ from .latent_node import (
NODE_CLASS_MAPPINGS as LATENT_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS as LATENT_DISPLAY_MAPPINGS,
)
from .merge_node import (
NODE_CLASS_MAPPINGS as MERGE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS as MERGE_DISPLAY_MAPPINGS,
)
NODE_CLASS_MAPPINGS.update(SAVE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(LATENT_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(MERGE_CLASS_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(SAVE_DISPLAY_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(LATENT_DISPLAY_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(MERGE_DISPLAY_MAPPINGS)
WEB_DIRECTORY = "./web/js"

merge_node.py Normal file

@@ -0,0 +1,186 @@
import torch
import numpy as np
OPTICAL_FLOW_PRESETS = {
'fast': {'levels': 2, 'winsize': 11, 'iterations': 2, 'poly_n': 5, 'poly_sigma': 1.1},
'balanced': {'levels': 3, 'winsize': 15, 'iterations': 3, 'poly_n': 5, 'poly_sigma': 1.2},
'quality': {'levels': 5, 'winsize': 21, 'iterations': 5, 'poly_n': 7, 'poly_sigma': 1.5},
'max': {'levels': 7, 'winsize': 31, 'iterations': 10, 'poly_n': 7, 'poly_sigma': 1.5},
}
PASS_THROUGH_MODES = {"Edge Extend", "Frame Interpolation", "Keyframe", "Video Inpaint"}
def _count_leading_black(mask):
"""Count consecutive black (context) frames at the start of mask."""
count = 0
for i in range(mask.shape[0]):
if mask[i].max().item() < 0.01:
count += 1
else:
break
return count
def _count_trailing_black(mask):
"""Count consecutive black (context) frames at the end of mask."""
count = 0
for i in range(mask.shape[0] - 1, -1, -1):
if mask[i].max().item() < 0.01:
count += 1
else:
break
return count
def _alpha_blend(frame_a, frame_b, alpha):
"""Simple linear crossfade between two frames (H,W,3 tensors)."""
return frame_a * (1.0 - alpha) + frame_b * alpha
def _optical_flow_blend(frame_a, frame_b, alpha, preset):
"""Motion-compensated blend using Farneback optical flow."""
try:
import cv2
except ImportError:
return _alpha_blend(frame_a, frame_b, alpha)
params = OPTICAL_FLOW_PRESETS[preset]
arr_a = (frame_a.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
arr_b = (frame_b.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
gray_a = cv2.cvtColor(arr_a, cv2.COLOR_RGB2GRAY)
gray_b = cv2.cvtColor(arr_b, cv2.COLOR_RGB2GRAY)
flow = cv2.calcOpticalFlowFarneback(
gray_a, gray_b, None,
pyr_scale=0.5,
levels=params['levels'],
winsize=params['winsize'],
iterations=params['iterations'],
poly_n=params['poly_n'],
poly_sigma=params['poly_sigma'],
flags=0,
)
h, w = flow.shape[:2]
x_coords = np.tile(np.arange(w), (h, 1)).astype(np.float32)
y_coords = np.tile(np.arange(h), (w, 1)).T.astype(np.float32)
# Warp A forward by alpha * flow
flow_fwd = flow * alpha
warped_a = cv2.remap(
arr_a,
x_coords + flow_fwd[..., 0],
y_coords + flow_fwd[..., 1],
cv2.INTER_LINEAR,
borderMode=cv2.BORDER_REPLICATE,
)
# Warp B backward by -(1-alpha) * flow
flow_back = -flow * (1 - alpha)
warped_b = cv2.remap(
arr_b,
x_coords + flow_back[..., 0],
y_coords + flow_back[..., 1],
cv2.INTER_LINEAR,
borderMode=cv2.BORDER_REPLICATE,
)
result = cv2.addWeighted(warped_a, 1 - alpha, warped_b, alpha, 0)
return torch.from_numpy(result.astype(np.float32) / 255.0).to(frame_a.device)
class VACEMergeBack:
CATEGORY = "VACE Tools"
FUNCTION = "merge"
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("merged_clip",)
OUTPUT_TOOLTIPS = (
"Full reconstructed video with VACE output spliced back into the original clip.",
)
DESCRIPTION = """VACE Merge Back — splices VACE sampler output back into the original full-length video.
Connect the original (untrimmed) clip, the VACE sampler output, the mask from VACE Mask Generator,
and the mode/trim_start/trim_end from VACE Source Prep. The node detects context zones from the mask
and blends at the seams where context meets generated frames.
Pass-through modes (Edge Extend, Frame Interpolation, Keyframe, Video Inpaint):
Returns vace_output as-is — the VACE output IS the final result.
Splice modes (End, Pre, Middle, Join, Bidirectional, Replace):
Reconstructs original[:trim_start] + vace_output + original[trim_end:]
with optional blending at the seams.
Blend methods:
none — Hard cut at seams (fastest)
alpha — Simple linear crossfade
optical_flow — Motion-compensated blend using Farneback dense optical flow"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_clip": ("IMAGE", {"description": "Full original video (before any trimming)."}),
"vace_output": ("IMAGE", {"description": "VACE sampler output."}),
"mask": ("IMAGE", {"description": "Mask from VACE Mask Generator — BLACK=context, WHITE=generated."}),
"mode": ("STRING", {"forceInput": True, "description": "Mode from VACE Source Prep."}),
"trim_start": ("INT", {"forceInput": True, "default": 0, "description": "Start of trimmed region in original."}),
"trim_end": ("INT", {"forceInput": True, "default": 0, "description": "End of trimmed region in original."}),
"blend_frames": ("INT", {"default": 4, "min": 0, "max": 100, "description": "Context frames to blend at each seam (0 = hard cut)."}),
"blend_method": (["optical_flow", "alpha", "none"], {"default": "optical_flow", "description": "Blending method at seams."}),
"of_preset": (["fast", "balanced", "quality", "max"], {"default": "balanced", "description": "Optical flow quality preset."}),
},
}
def merge(self, original_clip, vace_output, mask, mode, trim_start, trim_end, blend_frames, blend_method, of_preset):
# Pass-through modes: VACE output IS the final result
if mode in PASS_THROUGH_MODES:
return (vace_output,)
# Splice modes: reconstruct full video
V = vace_output.shape[0]
head = original_clip[:trim_start]
tail = original_clip[trim_end:]
result = torch.cat([head, vace_output, tail], dim=0)
if blend_method == "none" or blend_frames <= 0:
return (result,)
# Detect context zones from mask
left_ctx_len = _count_leading_black(mask)
right_ctx_len = _count_trailing_black(mask)
def blend_frame(orig, vace, alpha):
if blend_method == "optical_flow":
return _optical_flow_blend(orig, vace, alpha, of_preset)
return _alpha_blend(orig, vace, alpha)
# Blend at LEFT seam (context → generated transition)
bf_left = min(blend_frames, left_ctx_len)
for j in range(bf_left):
alpha = (j + 1) / (bf_left + 1)
orig_frame = original_clip[trim_start + j]
vace_frame = vace_output[j]
result[trim_start + j] = blend_frame(orig_frame, vace_frame, alpha)
# Blend at RIGHT seam (generated → context transition)
bf_right = min(blend_frames, right_ctx_len)
for j in range(bf_right):
alpha = 1.0 - (j + 1) / (bf_right + 1)
frame_idx = V - bf_right + j
orig_frame = original_clip[trim_end - bf_right + j]
vace_frame = vace_output[frame_idx]
result[trim_start + frame_idx] = blend_frame(orig_frame, vace_frame, alpha)
return (result,)
NODE_CLASS_MAPPINGS = {
"VACEMergeBack": VACEMergeBack,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"VACEMergeBack": "VACE Merge Back",
}


@@ -330,11 +330,11 @@ If your source is longer, use VACE Source Prep upstream to trim it first."""
class VACESourcePrep:
CATEGORY = "VACE Tools"
FUNCTION = "prepare"
RETURN_TYPES = ("IMAGE", "STRING", "INT", "INT", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "MASK", "STRING")
RETURN_TYPES = ("IMAGE", "STRING", "INT", "INT", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "MASK", "STRING", "INT", "INT")
RETURN_NAMES = (
"source_clip", "mode", "split_index", "edge_frames",
"segment_1", "segment_2", "segment_3", "segment_4",
"inpaint_mask", "keyframe_positions",
"inpaint_mask", "keyframe_positions", "trim_start", "trim_end",
)
OUTPUT_TOOLTIPS = (
"Trimmed source frames — wire to VACE Mask Generator's source_clip.",
@@ -347,6 +347,8 @@ class VACESourcePrep:
"Segment 4: Join: part 4. Others: placeholder.",
"Inpaint mask trimmed to match output — wire to VACE Mask Generator.",
"Keyframe positions pass-through — wire to VACE Mask Generator.",
"Start index of the trimmed region in the original clip — wire to VACE Merge Back.",
"End index of the trimmed region in the original clip — wire to VACE Merge Back.",
)
DESCRIPTION = """VACE Source Prep — trims long source clips for VACE Mask Generator.
@@ -483,7 +485,7 @@ input_left / input_right (0 = use all available):
else:
output = source_clip
start = 0
return (output, mode, 0, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out)
return (output, mode, 0, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out, start, B)
elif mode == "Pre Extend":
if input_right > 0:
@@ -492,7 +494,7 @@ input_left / input_right (0 = use all available):
else:
output = source_clip
end = B
return (output, mode, output.shape[0], edge_frames, safe(output), ph(), ph(), ph(), trim_mask(0, end), kp_out)
return (output, mode, output.shape[0], edge_frames, safe(output), ph(), ph(), ph(), trim_mask(0, end), kp_out, 0, end)
elif mode == "Middle Extend":
left_start = max(0, split_index - input_left) if input_left > 0 else 0
@@ -501,7 +503,7 @@ input_left / input_right (0 = use all available):
out_split = split_index - left_start
part_a = source_clip[left_start:split_index]
part_b = source_clip[split_index:right_end]
return (output, mode, out_split, edge_frames, safe(part_a), safe(part_b), ph(), ph(), trim_mask(left_start, right_end), kp_out)
return (output, mode, out_split, edge_frames, safe(part_a), safe(part_b), ph(), ph(), trim_mask(left_start, right_end), kp_out, left_start, right_end)
elif mode == "Edge Extend":
eff_left = min(input_left if input_left > 0 else edge_frames, B)
@@ -511,7 +513,7 @@ input_left / input_right (0 = use all available):
end_seg = source_clip[-sym:] if sym > 0 else source_clip[:0]
mid_seg = source_clip[sym:B - sym] if 2 * sym < B else source_clip[:0]
output = torch.cat([start_seg, end_seg], dim=0)
return (output, mode, 0, sym, safe(start_seg), safe(mid_seg), safe(end_seg), ph(), mask_ph(), kp_out)
return (output, mode, 0, sym, safe(start_seg), safe(mid_seg), safe(end_seg), ph(), mask_ph(), kp_out, 0, B)
elif mode == "Join Extend":
half = B // 2
@@ -527,7 +529,7 @@ input_left / input_right (0 = use all available):
part_3 = second_half[:sym]
part_4 = second_half[sym:]
output = torch.cat([part_2, part_3], dim=0)
return (output, mode, 0, sym, safe(part_1), safe(part_2), safe(part_3), safe(part_4), mask_ph(), kp_out)
return (output, mode, 0, sym, safe(part_1), safe(part_2), safe(part_3), safe(part_4), mask_ph(), kp_out, half - sym, half + sym)
elif mode == "Bidirectional Extend":
if input_left > 0:
@@ -536,10 +538,10 @@ input_left / input_right (0 = use all available):
else:
output = source_clip
start = 0
return (output, mode, split_index, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out)
return (output, mode, split_index, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out, start, B)
elif mode == "Frame Interpolation":
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), trim_mask(0, B), kp_out)
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), trim_mask(0, B), kp_out, 0, B)
elif mode == "Replace/Inpaint":
start = max(0, min(split_index, B))
@@ -553,14 +555,14 @@ input_left / input_right (0 = use all available):
output = torch.cat([before, replace_region, after], dim=0)
out_split = before.shape[0]
out_edge = length
return (output, mode, out_split, out_edge, safe(before), safe(replace_region), safe(after), ph(), trim_mask(ctx_start, ctx_end), kp_out)
return (output, mode, out_split, out_edge, safe(before), safe(replace_region), safe(after), ph(), trim_mask(ctx_start, ctx_end), kp_out, ctx_start, ctx_end)
elif mode == "Video Inpaint":
out_mask = inpaint_mask.to(dev) if inpaint_mask is not None else mask_ph()
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), out_mask, kp_out)
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), out_mask, kp_out, 0, B)
elif mode == "Keyframe":
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), mask_ph(), kp_out)
return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), mask_ph(), kp_out, 0, B)
raise ValueError(f"Unknown mode: {mode}")


@@ -76,3 +76,63 @@ app.registerExtension({
updateVisibility(modeWidget.value);
},
});
app.registerExtension({
name: "VACE.MergeBack.SmartDisplay",
nodeCreated(node) {
if (node.comfyClass !== "VACEMergeBack") return;
const methodWidget = node.widgets.find(w => w.name === "blend_method");
if (!methodWidget) return;
function toggleWidget(widget, show) {
if (!widget) return;
if (!widget._origType) widget._origType = widget.type;
widget.type = show ? widget._origType : "hidden";
}
function updateVisibility(method) {
const showBlend = method !== "none";
const showOf = method === "optical_flow";
toggleWidget(node.widgets.find(w => w.name === "blend_frames"), showBlend);
toggleWidget(node.widgets.find(w => w.name === "of_preset"), showOf);
node.setSize(node.computeSize());
app.graph.setDirtyCanvas(true);
}
const descriptor = Object.getOwnPropertyDescriptor(methodWidget, "value") ||
{ configurable: true };
const hasCustomAccessor = !!descriptor.get;
if (!hasCustomAccessor) {
let _value = methodWidget.value;
Object.defineProperty(methodWidget, "value", {
get() { return _value; },
set(v) {
_value = v;
updateVisibility(v);
},
configurable: true,
});
} else {
const origGet = descriptor.get;
const origSet = descriptor.set;
Object.defineProperty(methodWidget, "value", {
get() { return origGet.call(this); },
set(v) {
origSet.call(this, v);
updateVisibility(v);
},
configurable: true,
});
}
const origCallback = methodWidget.callback;
methodWidget.callback = function(value) {
updateVisibility(value);
if (origCallback) origCallback.call(this, value);
};
updateVisibility(methodWidget.value);
},
});