Optimize merge node RAM usage with pre-allocation and context slicing

Replace torch.cat with pre-allocated tensor + in-place copy.
Clone only small context slices for blending instead of holding
full source/vace tensors during the blend loop.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-21 16:51:24 +01:00
parent 0f17311199
commit 615599cdfd

View File

@@ -126,16 +126,35 @@ Blend methods:
# Splice modes: reconstruct full video
two_clip = vace_pipe.get("two_clip", False)
V = vace_output.shape[0]
head = source_clip[:trim_start] tail_src = source_clip_2 if (two_clip and source_clip_2 is not None) else source_clip
if two_clip and source_clip_2 is not None: tail_len = max(0, tail_src.shape[0] - trim_end)
tail = source_clip_2[trim_end:] total = trim_start + V + tail_len
right_orig = source_clip_2 need_blend = blend_method != "none" and (left_ctx > 0 or right_ctx > 0)
else:
tail = source_clip[trim_end:]
right_orig = source_clip
result = torch.cat([head, vace_output, tail], dim=0)
if blend_method == "none" or (left_ctx == 0 and right_ctx == 0): # Clone only the small context slices needed for blending
# so the full source/vace tensors aren't held during the blend loop
left_orig_ctx = left_vace_ctx = right_orig_ctx = right_vace_ctx = None
if need_blend:
if left_ctx > 0:
n = min(left_ctx, max(0, source_clip.shape[0] - trim_start))
if n > 0:
left_orig_ctx = source_clip[trim_start:trim_start + n].clone()
left_vace_ctx = vace_output[:n].clone()
if right_ctx > 0:
rs = trim_end - right_ctx
n = min(right_ctx, max(0, tail_src.shape[0] - rs))
if n > 0:
right_orig_ctx = tail_src[rs:rs + n].clone()
right_vace_ctx = vace_output[V - right_ctx:V - right_ctx + n].clone()
# Pre-allocate and copy in-place (avoids torch.cat allocation overhead)
result = torch.empty((total,) + source_clip.shape[1:], dtype=source_clip.dtype, device=source_clip.device)
result[:trim_start] = source_clip[:trim_start]
result[trim_start:trim_start + V] = vace_output
if tail_len > 0:
result[trim_start + V:] = tail_src[trim_end:]
if not need_blend:
return (result,)
def blend_frame(orig, vace, alpha):
@@ -143,20 +162,17 @@ Blend methods:
return _optical_flow_blend(orig, vace, alpha, of_preset)
return _alpha_blend(orig, vace, alpha)
# Blend across full left context zone # Blend using saved context slices (not the full tensors)
for j in range(left_ctx): if left_orig_ctx is not None:
if trim_start + j >= source_clip.shape[0]: for j in range(left_orig_ctx.shape[0]):
break alpha = (j + 1) / (left_ctx + 1)
alpha = (j + 1) / (left_ctx + 1) result[trim_start + j] = blend_frame(left_orig_ctx[j], left_vace_ctx[j], alpha)
result[trim_start + j] = blend_frame(source_clip[trim_start + j], vace_output[j], alpha)
# Blend across full right context zone if right_orig_ctx is not None:
for j in range(right_ctx): for j in range(right_orig_ctx.shape[0]):
if trim_end - right_ctx + j >= right_orig.shape[0]: alpha = 1.0 - (j + 1) / (right_ctx + 1)
break frame_idx = V - right_ctx + j
alpha = 1.0 - (j + 1) / (right_ctx + 1) result[trim_start + frame_idx] = blend_frame(right_orig_ctx[j], right_vace_ctx[j], alpha)
frame_idx = V - right_ctx + j
result[trim_start + frame_idx] = blend_frame(right_orig[trim_end - right_ctx + j], vace_output[frame_idx], alpha)
return (result,)