Replace torch.cat with pre-allocated tensor + in-place copy. Clone only small context slices for blending instead of holding full source/vace tensors during the blend loop. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
187 lines
7.4 KiB
Python
187 lines
7.4 KiB
Python
import torch
|
|
import numpy as np
|
|
|
|
|
|
OPTICAL_FLOW_PRESETS = {
|
|
'fast': {'levels': 2, 'winsize': 11, 'iterations': 2, 'poly_n': 5, 'poly_sigma': 1.1},
|
|
'balanced': {'levels': 3, 'winsize': 15, 'iterations': 3, 'poly_n': 5, 'poly_sigma': 1.2},
|
|
'quality': {'levels': 5, 'winsize': 21, 'iterations': 5, 'poly_n': 7, 'poly_sigma': 1.5},
|
|
'max': {'levels': 7, 'winsize': 31, 'iterations': 10, 'poly_n': 7, 'poly_sigma': 1.5},
|
|
}
|
|
|
|
PASS_THROUGH_MODES = {"Edge Extend", "Frame Interpolation", "Keyframe", "Video Inpaint"}
|
|
|
|
|
|
|
|
def _alpha_blend(frame_a, frame_b, alpha):
|
|
"""Simple linear crossfade between two frames (H,W,3 tensors)."""
|
|
return frame_a * (1.0 - alpha) + frame_b * alpha
|
|
|
|
|
|
def _optical_flow_blend(frame_a, frame_b, alpha, preset):
|
|
"""Motion-compensated blend using Farneback optical flow."""
|
|
try:
|
|
import cv2
|
|
except ImportError:
|
|
return _alpha_blend(frame_a, frame_b, alpha)
|
|
|
|
params = OPTICAL_FLOW_PRESETS[preset]
|
|
|
|
arr_a = (frame_a.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
|
arr_b = (frame_b.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
|
|
|
gray_a = cv2.cvtColor(arr_a, cv2.COLOR_RGB2GRAY)
|
|
gray_b = cv2.cvtColor(arr_b, cv2.COLOR_RGB2GRAY)
|
|
flow = cv2.calcOpticalFlowFarneback(
|
|
gray_a, gray_b, None,
|
|
pyr_scale=0.5,
|
|
levels=params['levels'],
|
|
winsize=params['winsize'],
|
|
iterations=params['iterations'],
|
|
poly_n=params['poly_n'],
|
|
poly_sigma=params['poly_sigma'],
|
|
flags=0,
|
|
)
|
|
|
|
h, w = flow.shape[:2]
|
|
x_coords = np.tile(np.arange(w), (h, 1)).astype(np.float32)
|
|
y_coords = np.tile(np.arange(h), (w, 1)).T.astype(np.float32)
|
|
|
|
# Warp A forward by alpha * flow
|
|
flow_fwd = flow * alpha
|
|
warped_a = cv2.remap(
|
|
arr_a,
|
|
x_coords + flow_fwd[..., 0],
|
|
y_coords + flow_fwd[..., 1],
|
|
cv2.INTER_LINEAR,
|
|
borderMode=cv2.BORDER_REPLICATE,
|
|
)
|
|
|
|
# Warp B backward by -(1-alpha) * flow
|
|
flow_back = -flow * (1 - alpha)
|
|
warped_b = cv2.remap(
|
|
arr_b,
|
|
x_coords + flow_back[..., 0],
|
|
y_coords + flow_back[..., 1],
|
|
cv2.INTER_LINEAR,
|
|
borderMode=cv2.BORDER_REPLICATE,
|
|
)
|
|
|
|
result = cv2.addWeighted(warped_a, 1 - alpha, warped_b, alpha, 0)
|
|
return torch.from_numpy(result.astype(np.float32) / 255.0).to(frame_a.device)
|
|
|
|
|
|
class VACEMergeBack:
|
|
CATEGORY = "VACE Tools"
|
|
FUNCTION = "merge"
|
|
RETURN_TYPES = ("IMAGE",)
|
|
RETURN_NAMES = ("merged_clip",)
|
|
OUTPUT_TOOLTIPS = (
|
|
"Full reconstructed video with VACE output spliced back into the original clip.",
|
|
)
|
|
DESCRIPTION = """VACE Merge Back — splices VACE sampler output back into the original full-length video.
|
|
|
|
Connect the original (untrimmed) clip, the VACE sampler output, and the vace_pipe from VACE Source Prep.
|
|
The pipe carries mode, trim bounds, and context frame counts for automatic blending.
|
|
|
|
Pass-through modes (Edge Extend, Frame Interpolation, Keyframe, Video Inpaint):
|
|
Returns vace_output as-is — the VACE output IS the final result.
|
|
|
|
Splice modes (End, Pre, Middle, Join, Bidirectional, Replace):
|
|
Reconstructs original[:trim_start] + vace_output + original[trim_end:]
|
|
with automatic blending across the full context zones.
|
|
For two-clip Join Extend, connect source_clip_2 — the tail comes from the second clip.
|
|
|
|
Blend methods:
|
|
none — Hard cut at seams (fastest)
|
|
alpha — Simple linear crossfade
|
|
optical_flow — Motion-compensated blend using Farneback dense optical flow"""
|
|
|
|
@classmethod
|
|
def INPUT_TYPES(cls):
|
|
return {
|
|
"required": {
|
|
"source_clip": ("IMAGE", {"description": "Full original video (before any trimming)."}),
|
|
"vace_output": ("IMAGE", {"description": "VACE sampler output."}),
|
|
"vace_pipe": ("VACE_PIPE", {"description": "Pipe from VACE Source Prep carrying mode, trim bounds, and context counts."}),
|
|
"blend_method": (["optical_flow", "alpha", "none"], {"default": "optical_flow", "description": "Blending method at seams."}),
|
|
"of_preset": (["fast", "balanced", "quality", "max"], {"default": "balanced", "description": "Optical flow quality preset."}),
|
|
},
|
|
"optional": {
|
|
"source_clip_2": ("IMAGE", {"description": "Second original clip for Join Extend with two separate clips."}),
|
|
},
|
|
}
|
|
|
|
def merge(self, source_clip, vace_output, vace_pipe, blend_method, of_preset, source_clip_2=None):
|
|
mode = vace_pipe["mode"]
|
|
trim_start = vace_pipe["trim_start"]
|
|
trim_end = vace_pipe["trim_end"]
|
|
left_ctx = vace_pipe["left_ctx"]
|
|
right_ctx = vace_pipe["right_ctx"]
|
|
|
|
# Pass-through modes: VACE output IS the final result
|
|
if mode in PASS_THROUGH_MODES:
|
|
return (vace_output,)
|
|
|
|
# Splice modes: reconstruct full video
|
|
two_clip = vace_pipe.get("two_clip", False)
|
|
V = vace_output.shape[0]
|
|
tail_src = source_clip_2 if (two_clip and source_clip_2 is not None) else source_clip
|
|
tail_len = max(0, tail_src.shape[0] - trim_end)
|
|
total = trim_start + V + tail_len
|
|
need_blend = blend_method != "none" and (left_ctx > 0 or right_ctx > 0)
|
|
|
|
# Clone only the small context slices needed for blending
|
|
# so the full source/vace tensors aren't held during the blend loop
|
|
left_orig_ctx = left_vace_ctx = right_orig_ctx = right_vace_ctx = None
|
|
if need_blend:
|
|
if left_ctx > 0:
|
|
n = min(left_ctx, max(0, source_clip.shape[0] - trim_start))
|
|
if n > 0:
|
|
left_orig_ctx = source_clip[trim_start:trim_start + n].clone()
|
|
left_vace_ctx = vace_output[:n].clone()
|
|
if right_ctx > 0:
|
|
rs = trim_end - right_ctx
|
|
n = min(right_ctx, max(0, tail_src.shape[0] - rs))
|
|
if n > 0:
|
|
right_orig_ctx = tail_src[rs:rs + n].clone()
|
|
right_vace_ctx = vace_output[V - right_ctx:V - right_ctx + n].clone()
|
|
|
|
# Pre-allocate and copy in-place (avoids torch.cat allocation overhead)
|
|
result = torch.empty((total,) + source_clip.shape[1:], dtype=source_clip.dtype, device=source_clip.device)
|
|
result[:trim_start] = source_clip[:trim_start]
|
|
result[trim_start:trim_start + V] = vace_output
|
|
if tail_len > 0:
|
|
result[trim_start + V:] = tail_src[trim_end:]
|
|
|
|
if not need_blend:
|
|
return (result,)
|
|
|
|
def blend_frame(orig, vace, alpha):
|
|
if blend_method == "optical_flow":
|
|
return _optical_flow_blend(orig, vace, alpha, of_preset)
|
|
return _alpha_blend(orig, vace, alpha)
|
|
|
|
# Blend using saved context slices (not the full tensors)
|
|
if left_orig_ctx is not None:
|
|
for j in range(left_orig_ctx.shape[0]):
|
|
alpha = (j + 1) / (left_ctx + 1)
|
|
result[trim_start + j] = blend_frame(left_orig_ctx[j], left_vace_ctx[j], alpha)
|
|
|
|
if right_orig_ctx is not None:
|
|
for j in range(right_orig_ctx.shape[0]):
|
|
alpha = 1.0 - (j + 1) / (right_ctx + 1)
|
|
frame_idx = V - right_ctx + j
|
|
result[trim_start + frame_idx] = blend_frame(right_orig_ctx[j], right_vace_ctx[j], alpha)
|
|
|
|
return (result,)
|
|
|
|
|
|
NODE_CLASS_MAPPINGS = {
|
|
"VACEMergeBack": VACEMergeBack,
|
|
}
|
|
|
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
"VACEMergeBack": "VACE Merge Back",
|
|
}
|