From 89fa3405cbe4798bae828a5508f58a89edee5732 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 19 Feb 2026 22:43:07 +0100 Subject: [PATCH] Add VACE Merge Back node for splicing VACE output into original video Adds a new node that reconstructs full-length video by splicing VACE sampler output back into the original clip at the trim positions. Supports optical flow, alpha, and hard-cut blending at context/generated seams. Also adds trim_start/trim_end INT outputs to VACESourcePrep. Co-Authored-By: Claude Opus 4.6 --- README.md | 71 +++++++++++++++- __init__.py | 6 ++ merge_node.py | 186 +++++++++++++++++++++++++++++++++++++++++ nodes.py | 26 +++--- web/js/vace_widgets.js | 60 +++++++++++++ 5 files changed, 336 insertions(+), 13 deletions(-) create mode 100644 merge_node.py diff --git a/README.md b/README.md index fd948b2..88e876f 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,8 @@ Irrelevant widgets are automatically hidden based on the selected mode. | `segment_1`–`segment_4` | IMAGE | Frame segments per mode (same meaning as mask generator segments). Unused segments are 1-frame black placeholders. | | `inpaint_mask` | MASK | Trimmed to match output, or placeholder. | | `keyframe_positions` | STRING | Pass-through. | +| `trim_start` | INT | Start index of the trimmed region in the original clip — wire to VACE Merge Back. | +| `trim_end` | INT | End index of the trimmed region in the original clip — wire to VACE Merge Back. | ### Per-Mode Trimming @@ -314,6 +316,72 @@ control_frames: [ k0][ GREY ][ k1][ GREY ][ k2][ GREY ][ k3] --- +## Node: VACE Merge Back + +Splices VACE sampler output back into the original full-length video. Connect the original (untrimmed) clip, the VACE sampler output, the mask from VACE Mask Generator, and the `mode`/`trim_start`/`trim_end` from VACE Source Prep. + +Irrelevant widgets are automatically hidden based on the selected blend method. 
+ +### Inputs + +| Input | Type | Default | Description | +|---|---|---|---| +| `original_clip` | IMAGE | — | Full original video (before any trimming). | +| `vace_output` | IMAGE | — | VACE sampler output. | +| `mask` | IMAGE | — | Mask from VACE Mask Generator — BLACK=context, WHITE=generated. | +| `mode` | STRING | *(wired)* | Mode from VACE Source Prep (must be wired, not typed). | +| `trim_start` | INT | *(wired)* | Start of trimmed region in original (from VACE Source Prep). | +| `trim_end` | INT | *(wired)* | End of trimmed region in original (from VACE Source Prep). | +| `blend_frames` | INT | `4` | Context frames to blend at each seam (0 = hard cut). | +| `blend_method` | ENUM | `optical_flow` | `none` (hard cut), `alpha` (linear crossfade), or `optical_flow` (motion-compensated). | +| `of_preset` | ENUM | `balanced` | Optical flow quality: `fast`, `balanced`, `quality`, `max`. | + +### Outputs + +| Output | Type | Description | +|---|---|---| +| `merged_clip` | IMAGE | Full reconstructed video. | + +### Behavior + +**Pass-through modes** (Edge Extend, Frame Interpolation, Keyframe, Video Inpaint): returns `vace_output` as-is — the VACE output IS the final result for these modes. + +**Splice modes** (End, Pre, Middle, Join, Bidirectional, Replace): reconstructs `original[:trim_start] + vace_output + original[trim_end:]`, then blends at the seams where context frames meet original frames. + +The node detects context zones by counting consecutive black frames at the start and end of the mask. At each seam, `blend_frames` frames are blended with a smooth alpha ramp. Optical flow blending warps both frames along the motion field before blending, reducing ghosting on moving subjects. 
+ +### Example: Middle Extend + +``` +Original: 274 frames (0–273) +Prep: split_index=137, input_left=16, input_right=16 + → trim_start=121, trim_end=153, trimmed=32 frames +Mask Gen: target_frames=81 + → mask = [BLACK×16] [WHITE×49] [BLACK×16] +VACE out: 81 frames (from sampler) +Merge: result = original[0:121] + vace[0:81] + original[153:274] + → 121 + 81 + 121 = 323 frames + Left blend: vace[0..3] ↔ original[121..124] + Right blend: vace[77..80] ↔ original[149..152] +``` + +### Wiring Diagram + +``` +[Load Video] + │ + ├─ source_clip ──→ [VACESourcePrep] ─┬─ source_clip ──→ [MaskGen] ─→ mask ──┐ + │ ├─ mode ───────────────────────────────┤ + │ ├─ trim_start ─────────────────────────┤ + │ └─ trim_end ──────────────────────────┤ + │ │ + └─ original_clip ───────────────────────────────────────────────────────────→ [VACEMergeBack] + │ + [Sampler] ─→ vace_output ────────────────┘ +``` + +--- + ## Node: WanVideo Save Merged Model Saves a WanVideo diffusion model (with merged LoRAs) as a `.safetensors` file. Found under the **WanVideoWrapper** category. @@ -379,4 +447,5 @@ Loads a LATENT from an absolute file path. Found under the **latent** category. ## Dependencies -PyTorch and safetensors, both bundled with ComfyUI. +- **PyTorch** and **safetensors** — bundled with ComfyUI. +- **OpenCV** (`cv2`) — optional, for optical flow blending in VACE Merge Back. Falls back to alpha blending if unavailable. 
diff --git a/__init__.py b/__init__.py index c4bb0eb..f361624 100644 --- a/__init__.py +++ b/__init__.py @@ -7,11 +7,17 @@ from .latent_node import ( NODE_CLASS_MAPPINGS as LATENT_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as LATENT_DISPLAY_MAPPINGS, ) +from .merge_node import ( + NODE_CLASS_MAPPINGS as MERGE_CLASS_MAPPINGS, + NODE_DISPLAY_NAME_MAPPINGS as MERGE_DISPLAY_MAPPINGS, +) NODE_CLASS_MAPPINGS.update(SAVE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(LATENT_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(MERGE_CLASS_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(SAVE_DISPLAY_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(LATENT_DISPLAY_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(MERGE_DISPLAY_MAPPINGS) WEB_DIRECTORY = "./web/js" diff --git a/merge_node.py b/merge_node.py new file mode 100644 index 0000000..c81b86d --- /dev/null +++ b/merge_node.py @@ -0,0 +1,186 @@ +import torch +import numpy as np + + +OPTICAL_FLOW_PRESETS = { + 'fast': {'levels': 2, 'winsize': 11, 'iterations': 2, 'poly_n': 5, 'poly_sigma': 1.1}, + 'balanced': {'levels': 3, 'winsize': 15, 'iterations': 3, 'poly_n': 5, 'poly_sigma': 1.2}, + 'quality': {'levels': 5, 'winsize': 21, 'iterations': 5, 'poly_n': 7, 'poly_sigma': 1.5}, + 'max': {'levels': 7, 'winsize': 31, 'iterations': 10, 'poly_n': 7, 'poly_sigma': 1.5}, +} + +PASS_THROUGH_MODES = {"Edge Extend", "Frame Interpolation", "Keyframe", "Video Inpaint"} + + +def _count_leading_black(mask): + """Count consecutive black (context) frames at the start of mask.""" + count = 0 + for i in range(mask.shape[0]): + if mask[i].max().item() < 0.01: + count += 1 + else: + break + return count + + +def _count_trailing_black(mask): + """Count consecutive black (context) frames at the end of mask.""" + count = 0 + for i in range(mask.shape[0] - 1, -1, -1): + if mask[i].max().item() < 0.01: + count += 1 + else: + break + return count + + +def _alpha_blend(frame_a, frame_b, alpha): + """Simple linear crossfade between two frames (H,W,3 tensors).""" + return 
frame_a * (1.0 - alpha) + frame_b * alpha + + +def _optical_flow_blend(frame_a, frame_b, alpha, preset): + """Motion-compensated blend using Farneback optical flow.""" + try: + import cv2 + except ImportError: + return _alpha_blend(frame_a, frame_b, alpha) + + params = OPTICAL_FLOW_PRESETS[preset] + + arr_a = (frame_a.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + arr_b = (frame_b.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + + gray_a = cv2.cvtColor(arr_a, cv2.COLOR_RGB2GRAY) + gray_b = cv2.cvtColor(arr_b, cv2.COLOR_RGB2GRAY) + flow = cv2.calcOpticalFlowFarneback( + gray_a, gray_b, None, + pyr_scale=0.5, + levels=params['levels'], + winsize=params['winsize'], + iterations=params['iterations'], + poly_n=params['poly_n'], + poly_sigma=params['poly_sigma'], + flags=0, + ) + + h, w = flow.shape[:2] + x_coords = np.tile(np.arange(w), (h, 1)).astype(np.float32) + y_coords = np.tile(np.arange(h), (w, 1)).T.astype(np.float32) + + # Warp A forward by alpha * flow + flow_fwd = flow * alpha + warped_a = cv2.remap( + arr_a, + x_coords + flow_fwd[..., 0], + y_coords + flow_fwd[..., 1], + cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE, + ) + + # Warp B backward by -(1-alpha) * flow + flow_back = -flow * (1 - alpha) + warped_b = cv2.remap( + arr_b, + x_coords + flow_back[..., 0], + y_coords + flow_back[..., 1], + cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE, + ) + + result = cv2.addWeighted(warped_a, 1 - alpha, warped_b, alpha, 0) + return torch.from_numpy(result.astype(np.float32) / 255.0).to(frame_a.device) + + +class VACEMergeBack: + CATEGORY = "VACE Tools" + FUNCTION = "merge" + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("merged_clip",) + OUTPUT_TOOLTIPS = ( + "Full reconstructed video with VACE output spliced back into the original clip.", + ) + DESCRIPTION = """VACE Merge Back — splices VACE sampler output back into the original full-length video. 
+ +Connect the original (untrimmed) clip, the VACE sampler output, the mask from VACE Mask Generator, +and the mode/trim_start/trim_end from VACE Source Prep. The node detects context zones from the mask +and blends at the seams where context meets generated frames. + +Pass-through modes (Edge Extend, Frame Interpolation, Keyframe, Video Inpaint): + Returns vace_output as-is — the VACE output IS the final result. + +Splice modes (End, Pre, Middle, Join, Bidirectional, Replace): + Reconstructs original[:trim_start] + vace_output + original[trim_end:] + with optional blending at the seams. + +Blend methods: + none — Hard cut at seams (fastest) + alpha — Simple linear crossfade + optical_flow — Motion-compensated blend using Farneback dense optical flow""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "original_clip": ("IMAGE", {"description": "Full original video (before any trimming)."}), + "vace_output": ("IMAGE", {"description": "VACE sampler output."}), + "mask": ("IMAGE", {"description": "Mask from VACE Mask Generator — BLACK=context, WHITE=generated."}), + "mode": ("STRING", {"forceInput": True, "description": "Mode from VACE Source Prep."}), + "trim_start": ("INT", {"forceInput": True, "default": 0, "description": "Start of trimmed region in original."}), + "trim_end": ("INT", {"forceInput": True, "default": 0, "description": "End of trimmed region in original."}), + "blend_frames": ("INT", {"default": 4, "min": 0, "max": 100, "description": "Context frames to blend at each seam (0 = hard cut)."}), + "blend_method": (["optical_flow", "alpha", "none"], {"default": "optical_flow", "description": "Blending method at seams."}), + "of_preset": (["fast", "balanced", "quality", "max"], {"default": "balanced", "description": "Optical flow quality preset."}), + }, + } + + def merge(self, original_clip, vace_output, mask, mode, trim_start, trim_end, blend_frames, blend_method, of_preset): + # Pass-through modes: VACE output IS the final result + 
if mode in PASS_THROUGH_MODES: + return (vace_output,) + + # Splice modes: reconstruct full video + V = vace_output.shape[0] + head = original_clip[:trim_start] + tail = original_clip[trim_end:] + result = torch.cat([head, vace_output, tail], dim=0) + + if blend_method == "none" or blend_frames <= 0: + return (result,) + + # Detect context zones from mask + left_ctx_len = _count_leading_black(mask) + right_ctx_len = _count_trailing_black(mask) + + def blend_frame(orig, vace, alpha): + if blend_method == "optical_flow": + return _optical_flow_blend(orig, vace, alpha, of_preset) + return _alpha_blend(orig, vace, alpha) + + # Blend at LEFT seam (context → generated transition) + bf_left = min(blend_frames, left_ctx_len) + for j in range(bf_left): + alpha = (j + 1) / (bf_left + 1) + orig_frame = original_clip[trim_start + j] + vace_frame = vace_output[j] + result[trim_start + j] = blend_frame(orig_frame, vace_frame, alpha) + + # Blend at RIGHT seam (generated → context transition) + bf_right = min(blend_frames, right_ctx_len) + for j in range(bf_right): + alpha = 1.0 - (j + 1) / (bf_right + 1) + frame_idx = V - bf_right + j + orig_frame = original_clip[trim_end - bf_right + j] + vace_frame = vace_output[frame_idx] + result[trim_start + frame_idx] = blend_frame(orig_frame, vace_frame, alpha) + + return (result,) + + +NODE_CLASS_MAPPINGS = { + "VACEMergeBack": VACEMergeBack, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "VACEMergeBack": "VACE Merge Back", +} diff --git a/nodes.py b/nodes.py index 1467a96..018a535 100644 --- a/nodes.py +++ b/nodes.py @@ -330,11 +330,11 @@ If your source is longer, use VACE Source Prep upstream to trim it first.""" class VACESourcePrep: CATEGORY = "VACE Tools" FUNCTION = "prepare" - RETURN_TYPES = ("IMAGE", "STRING", "INT", "INT", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "MASK", "STRING") + RETURN_TYPES = ("IMAGE", "STRING", "INT", "INT", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "MASK", "STRING", "INT", "INT") RETURN_NAMES = ( "source_clip", "mode", 
"split_index", "edge_frames", "segment_1", "segment_2", "segment_3", "segment_4", - "inpaint_mask", "keyframe_positions", + "inpaint_mask", "keyframe_positions", "trim_start", "trim_end", ) OUTPUT_TOOLTIPS = ( "Trimmed source frames — wire to VACE Mask Generator's source_clip.", @@ -347,6 +347,8 @@ class VACESourcePrep: "Segment 4: Join: part 4. Others: placeholder.", "Inpaint mask trimmed to match output — wire to VACE Mask Generator.", "Keyframe positions pass-through — wire to VACE Mask Generator.", + "Start index of the trimmed region in the original clip — wire to VACE Merge Back.", + "End index of the trimmed region in the original clip — wire to VACE Merge Back.", ) DESCRIPTION = """VACE Source Prep — trims long source clips for VACE Mask Generator. @@ -483,7 +485,7 @@ input_left / input_right (0 = use all available): else: output = source_clip start = 0 - return (output, mode, 0, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out) + return (output, mode, 0, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out, start, B) elif mode == "Pre Extend": if input_right > 0: @@ -492,7 +494,7 @@ input_left / input_right (0 = use all available): else: output = source_clip end = B - return (output, mode, output.shape[0], edge_frames, safe(output), ph(), ph(), ph(), trim_mask(0, end), kp_out) + return (output, mode, output.shape[0], edge_frames, safe(output), ph(), ph(), ph(), trim_mask(0, end), kp_out, 0, end) elif mode == "Middle Extend": left_start = max(0, split_index - input_left) if input_left > 0 else 0 @@ -501,7 +503,7 @@ input_left / input_right (0 = use all available): out_split = split_index - left_start part_a = source_clip[left_start:split_index] part_b = source_clip[split_index:right_end] - return (output, mode, out_split, edge_frames, safe(part_a), safe(part_b), ph(), ph(), trim_mask(left_start, right_end), kp_out) + return (output, mode, out_split, edge_frames, safe(part_a), safe(part_b), ph(), ph(), 
trim_mask(left_start, right_end), kp_out, left_start, right_end) elif mode == "Edge Extend": eff_left = min(input_left if input_left > 0 else edge_frames, B) @@ -511,7 +513,7 @@ input_left / input_right (0 = use all available): end_seg = source_clip[-sym:] if sym > 0 else source_clip[:0] mid_seg = source_clip[sym:B - sym] if 2 * sym < B else source_clip[:0] output = torch.cat([start_seg, end_seg], dim=0) - return (output, mode, 0, sym, safe(start_seg), safe(mid_seg), safe(end_seg), ph(), mask_ph(), kp_out) + return (output, mode, 0, sym, safe(start_seg), safe(mid_seg), safe(end_seg), ph(), mask_ph(), kp_out, 0, B) elif mode == "Join Extend": half = B // 2 @@ -527,7 +529,7 @@ input_left / input_right (0 = use all available): part_3 = second_half[:sym] part_4 = second_half[sym:] output = torch.cat([part_2, part_3], dim=0) - return (output, mode, 0, sym, safe(part_1), safe(part_2), safe(part_3), safe(part_4), mask_ph(), kp_out) + return (output, mode, 0, sym, safe(part_1), safe(part_2), safe(part_3), safe(part_4), mask_ph(), kp_out, half - sym, half + sym) elif mode == "Bidirectional Extend": if input_left > 0: @@ -536,10 +538,10 @@ input_left / input_right (0 = use all available): else: output = source_clip start = 0 - return (output, mode, split_index, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out) + return (output, mode, split_index, edge_frames, safe(output), ph(), ph(), ph(), trim_mask(start, B), kp_out, start, B) elif mode == "Frame Interpolation": - return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), trim_mask(0, B), kp_out) + return (source_clip, mode, split_index, edge_frames, safe(source_clip), ph(), ph(), ph(), trim_mask(0, B), kp_out, 0, B) elif mode == "Replace/Inpaint": start = max(0, min(split_index, B)) @@ -553,14 +555,14 @@ input_left / input_right (0 = use all available): output = torch.cat([before, replace_region, after], dim=0) out_split = before.shape[0] out_edge = length - return 
// Hides irrelevant VACE Merge Back widgets based on the selected blend method:
//   - blend_frames is hidden when blend_method is "none" (hard cut)
//   - of_preset is shown only for "optical_flow"
// Visibility is re-evaluated on the widget callback, on programmatic writes
// to methodWidget.value (e.g. workflow load), and once at node creation.
app.registerExtension({
    name: "VACE.MergeBack.SmartDisplay",
    nodeCreated(node) {
        if (node.comfyClass !== "VACEMergeBack") return;

        const methodWidget = node.widgets.find(w => w.name === "blend_method");
        if (!methodWidget) return;

        // Show/hide by swapping the widget type; the original type is stashed
        // once so it can be restored when the widget is shown again.
        function toggleWidget(widget, show) {
            if (!widget) return;
            if (!widget._origType) widget._origType = widget.type;
            widget.type = show ? widget._origType : "hidden";
        }

        function updateVisibility(method) {
            const showBlend = method !== "none";
            const showOf = method === "optical_flow";
            toggleWidget(node.widgets.find(w => w.name === "blend_frames"), showBlend);
            toggleWidget(node.widgets.find(w => w.name === "of_preset"), showOf);
            node.setSize(node.computeSize());
            app.graph.setDirtyCanvas(true);
        }

        // Intercept assignments to methodWidget.value so visibility stays in
        // sync even without a user interaction. Preserve any accessor another
        // extension may have installed before us.
        const descriptor = Object.getOwnPropertyDescriptor(methodWidget, "value") ||
            { configurable: true };
        const hasCustomAccessor = !!descriptor.get;

        if (!hasCustomAccessor) {
            let _value = methodWidget.value;
            Object.defineProperty(methodWidget, "value", {
                get() { return _value; },
                set(v) {
                    _value = v;
                    updateVisibility(v);
                },
                configurable: true,
            });
        } else {
            const origGet = descriptor.get;
            const origSet = descriptor.set;
            Object.defineProperty(methodWidget, "value", {
                get() { return origGet.call(this); },
                set(v) {
                    // FIX: the pre-existing accessor may be getter-only; calling
                    // an undefined setter would throw on every assignment.
                    if (origSet) origSet.call(this, v);
                    updateVisibility(v);
                },
                configurable: true,
            });
        }

        // Also hook the user-interaction callback path, chaining any
        // previously installed callback.
        const origCallback = methodWidget.callback;
        methodWidget.callback = function(value) {
            updateVisibility(value);
            if (origCallback) origCallback.call(this, value);
        };

        updateVisibility(methodWidget.value);
    },
});