Files
Comfyui-VACE-Tools/nodes.py
Ethanfel 87ec7b3938 Add Video Inpaint mode for per-pixel spatial mask regeneration
New 9th mode that works at the pixel level rather than the frame level.
Accepts an optional MASK input (B,H,W) to mark spatial regions for
regeneration, with single-frame broadcast, spatial dimension validation,
and contiguous output tensors.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 13:37:04 +01:00

245 lines
11 KiB
Python

import torch
def _create_solid_batch(count, height, width, color_value, device="cpu"):
"""Create a batch of solid-color frames (B, H, W, 3). Returns empty tensor if count <= 0."""
if count <= 0:
return torch.empty((0, height, width, 3), dtype=torch.float32, device=device)
return torch.full((count, height, width, 3), color_value, dtype=torch.float32, device=device)
def _placeholder(height, width, device="cpu"):
"""Create a single-frame black placeholder (1, H, W, 3)."""
return torch.zeros((1, height, width, 3), dtype=torch.float32, device=device)
def _ensure_nonempty(tensor, height, width, device="cpu"):
"""Replace a 0-frame tensor with a 1-frame black placeholder."""
if tensor.shape[0] == 0:
return _placeholder(height, width, device)
return tensor
class VACEMaskGenerator:
    """ComfyUI node that prepares mask / control-frame sequences for VACE.

    All frame tensors use ComfyUI's IMAGE layout (B, H, W, C) with float32
    values in [0, 1].  The emitted mask is an IMAGE sequence where black
    (0.0) marks frames/pixels to keep and white (1.0) marks content VACE
    should generate; the control frames show the source with neutral grey
    (#7f7f7f) wherever new content is requested.

    Fixes vs. prior revision:
      * frames_to_generate is clamped to >= 0 in every mode (previously
        End/Pre/Middle/Edge/Join could return a negative INT even though
        the tensors were already built as if it were 0 — Bidirectional
        Extend was the only mode that clamped).
      * Video Inpaint also accepts a 2D (H, W) mask and treats it as a
        single broadcastable frame.
    """

    CATEGORY = "VACE Tools"
    FUNCTION = "generate"
    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "INT")
    RETURN_NAMES = (
        "mask",
        "control_frames",
        "segment_1",
        "segment_2",
        "segment_3",
        "segment_4",
        "frames_to_generate",
    )
    OUTPUT_TOOLTIPS = (
        "Black/white mask sequence (target_frames long). Black = keep original, White = generate new.",
        "Source frames composited with grey (#7f7f7f) fill (target_frames long). Fed to VACE as visual reference.",
        "First clip segment. Contents depend on mode.",
        "Second clip segment. Placeholder if unused by the current mode.",
        "Third clip segment. Placeholder if unused by the current mode.",
        "Fourth clip segment. Placeholder if unused by the current mode.",
        "Number of new frames to generate (white/grey frames added).",
    )

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "source_clip": ("IMAGE", {"description": "Source video frames (B,H,W,C tensor)."}),
                "mode": (
                    [
                        "End Extend",
                        "Pre Extend",
                        "Middle Extend",
                        "Edge Extend",
                        "Join Extend",
                        "Bidirectional Extend",
                        "Frame Interpolation",
                        "Replace/Inpaint",
                        "Video Inpaint",
                    ],
                    {
                        "default": "End Extend",
                        "description": "End: generate after clip. Pre: generate before clip. Middle: generate at split point. Edge: generate between reversed edges (looping). Join: generate to heal two halves. Bidirectional: generate before AND after clip. Frame Interpolation: insert generated frames between each source pair. Replace/Inpaint: regenerate a range of frames in-place. Video Inpaint: regenerate masked spatial regions across all frames (requires inpaint_mask).",
                    },
                ),
                "target_frames": (
                    "INT",
                    {
                        "default": 81,
                        "min": 1,
                        "max": 10000,
                        "description": "Total output frame count for mask and control_frames. Unused by Frame Interpolation, Replace/Inpaint, and Video Inpaint.",
                    },
                ),
                "split_index": (
                    "INT",
                    {
                        "default": 0,
                        "min": -10000,
                        "max": 10000,
                        "description": "Where to split the source. End: trim from end (e.g. -16). Pre: reference frames from start (e.g. 24). Middle: split frame index. Unused by Edge/Join. Bidirectional: frames before clip (0 = even split). Frame Interpolation: new frames per gap. Replace/Inpaint: start index of replace region.",
                    },
                ),
                "edge_frames": (
                    "INT",
                    {
                        "default": 8,
                        "min": 1,
                        "max": 10000,
                        "description": "Number of edge frames to use for Edge and Join modes. Unused by End/Pre/Middle. Replace/Inpaint: number of frames to replace.",
                    },
                ),
            },
            "optional": {
                "inpaint_mask": (
                    "MASK",
                    {
                        "description": "Spatial inpaint mask for Video Inpaint mode. White (1.0) = regenerate, Black (0.0) = keep. Single frame broadcasts to all source frames.",
                    },
                ),
            },
        }

    def generate(self, source_clip, mode, target_frames, split_index, edge_frames, inpaint_mask=None):
        """Build the mask/control sequences for the selected mode.

        Args:
            source_clip: (B, H, W, C) float tensor of source frames.
            mode: one of the nine mode strings declared in INPUT_TYPES.
            target_frames: desired total length of mask/control outputs
                (ignored by Frame Interpolation, Replace/Inpaint, Video Inpaint).
            split_index: mode-dependent split/trim index (see tooltips).
            edge_frames: mode-dependent edge/replace count (see tooltips).
            inpaint_mask: optional (B, H, W) or (H, W) MASK tensor, required
                only by Video Inpaint.

        Returns:
            (mask, control_frames, segment_1..4, frames_to_generate) — see
            RETURN_NAMES / OUTPUT_TOOLTIPS.  frames_to_generate is always >= 0.

        Raises:
            ValueError: on an unknown mode, or a missing/mis-sized inpaint_mask
                in Video Inpaint mode.
        """
        B, H, W, C = source_clip.shape
        dev = source_clip.device

        BLACK = 0.0   # keep original content
        WHITE = 1.0   # regenerate content
        GREY = 0.498  # neutral fill shown to VACE (#7f7f7f)

        def solid(count, color):
            # Solid-color batch (count, H, W, 3); negative counts clamp to empty.
            return torch.full((max(count, 0), H, W, 3), color, dtype=torch.float32, device=dev)

        def ph():
            # Single black frame used for unused segment outputs.
            return torch.zeros((1, H, W, 3), dtype=torch.float32, device=dev)

        def safe(t):
            # Never hand downstream nodes an empty batch.
            return t if t.shape[0] > 0 else ph()

        if mode == "End Extend":
            # Keep the clip, append generated frames after it.
            frames_to_generate = max(0, target_frames - B)
            mask = torch.cat([solid(B, BLACK), solid(frames_to_generate, WHITE)], dim=0)
            control_frames = torch.cat([source_clip, solid(frames_to_generate, GREY)], dim=0)
            # split_index != 0 trims reference frames (negative trims from the tail).
            segment_1 = source_clip[:split_index] if split_index != 0 else source_clip
            return (mask, control_frames, safe(segment_1), ph(), ph(), ph(), frames_to_generate)

        elif mode == "Pre Extend":
            # Keep the first split_index frames as the anchor, generate before them.
            image_a = source_clip[:split_index]
            image_b = source_clip[split_index:]
            a_count = image_a.shape[0]
            frames_to_generate = max(0, target_frames - a_count)
            mask = torch.cat([solid(frames_to_generate, WHITE), solid(a_count, BLACK)], dim=0)
            control_frames = torch.cat([solid(frames_to_generate, GREY), image_a], dim=0)
            return (mask, control_frames, safe(image_b), ph(), ph(), ph(), frames_to_generate)

        elif mode == "Middle Extend":
            # Generate new content at the split point between the two halves.
            image_a = source_clip[:split_index]
            image_b = source_clip[split_index:]
            a_count = image_a.shape[0]
            b_count = image_b.shape[0]
            frames_to_generate = max(0, target_frames - (a_count + b_count))
            mask = torch.cat([solid(a_count, BLACK), solid(frames_to_generate, WHITE), solid(b_count, BLACK)], dim=0)
            control_frames = torch.cat([image_a, solid(frames_to_generate, GREY), image_b], dim=0)
            return (mask, control_frames, safe(image_a), safe(image_b), ph(), ph(), frames_to_generate)

        elif mode == "Edge Extend":
            # Looping: generate between the clip's end and its start (edges swapped).
            start_seg = source_clip[:edge_frames]
            end_seg = source_clip[-edge_frames:]
            mid_seg = source_clip[edge_frames:-edge_frames]
            start_count = start_seg.shape[0]
            end_count = end_seg.shape[0]
            frames_to_generate = max(0, target_frames - (start_count + end_count))
            mask = torch.cat([solid(end_count, BLACK), solid(frames_to_generate, WHITE), solid(start_count, BLACK)], dim=0)
            control_frames = torch.cat([end_seg, solid(frames_to_generate, GREY), start_seg], dim=0)
            return (mask, control_frames, start_seg, safe(mid_seg), end_seg, ph(), frames_to_generate)

        elif mode == "Join Extend":
            # Heal the seam between the clip's two halves with generated frames.
            half = B // 2
            first_half = source_clip[:half]
            second_half = source_clip[half:]
            part_1 = first_half[:-edge_frames]
            part_2 = first_half[-edge_frames:]
            part_3 = second_half[:edge_frames]
            part_4 = second_half[edge_frames:]
            p2_count = part_2.shape[0]
            p3_count = part_3.shape[0]
            frames_to_generate = max(0, target_frames - (p2_count + p3_count))
            mask = torch.cat([solid(p2_count, BLACK), solid(frames_to_generate, WHITE), solid(p3_count, BLACK)], dim=0)
            control_frames = torch.cat([part_2, solid(frames_to_generate, GREY), part_3], dim=0)
            return (mask, control_frames, safe(part_1), safe(part_2), safe(part_3), safe(part_4), frames_to_generate)

        elif mode == "Bidirectional Extend":
            # Generate both before and after the clip; split_index > 0 caps the
            # pre-clip share, otherwise split evenly.
            frames_to_generate = max(0, target_frames - B)
            if split_index > 0:
                pre_count = min(split_index, frames_to_generate)
            else:
                pre_count = frames_to_generate // 2
            post_count = frames_to_generate - pre_count
            mask = torch.cat([solid(pre_count, WHITE), solid(B, BLACK), solid(post_count, WHITE)], dim=0)
            control_frames = torch.cat([solid(pre_count, GREY), source_clip, solid(post_count, GREY)], dim=0)
            return (mask, control_frames, source_clip, ph(), ph(), ph(), frames_to_generate)

        elif mode == "Frame Interpolation":
            # Insert `step` generated frames into each gap between source frames.
            step = max(split_index, 1)
            frames_to_generate = (B - 1) * step
            mask_parts = []
            ctrl_parts = []
            for i in range(B):
                mask_parts.append(solid(1, BLACK))
                ctrl_parts.append(source_clip[i:i + 1])
                if i < B - 1:  # no gap after the last source frame
                    mask_parts.append(solid(step, WHITE))
                    ctrl_parts.append(solid(step, GREY))
            mask = torch.cat(mask_parts, dim=0)
            control_frames = torch.cat(ctrl_parts, dim=0)
            return (mask, control_frames, source_clip, ph(), ph(), ph(), frames_to_generate)

        elif mode == "Replace/Inpaint":
            # Regenerate [start, start+length) in place; bounds clamped to the clip.
            start = max(0, min(split_index, B))
            length = max(0, min(edge_frames, B - start))
            end = start + length
            frames_to_generate = length
            before = source_clip[:start]
            after = source_clip[end:]
            mask = torch.cat([solid(before.shape[0], BLACK), solid(length, WHITE), solid(after.shape[0], BLACK)], dim=0)
            control_frames = torch.cat([before, solid(length, GREY), after], dim=0)
            return (mask, control_frames, safe(before), safe(source_clip[start:end]), safe(after), ph(), frames_to_generate)

        elif mode == "Video Inpaint":
            # Per-pixel spatial regeneration across every frame.
            if inpaint_mask is None:
                raise ValueError("Video Inpaint mode requires the inpaint_mask input to be connected.")
            m = inpaint_mask.to(dev)  # (Bm, Hm, Wm) MASK type
            if m.dim() == 2:
                # Some mask sources emit a bare (H, W) tensor; treat it as one frame.
                m = m.unsqueeze(0)
            if m.shape[1] != H or m.shape[2] != W:
                raise ValueError(
                    f"Video Inpaint: inpaint_mask spatial size {m.shape[1]}x{m.shape[2]} "
                    f"doesn't match source_clip {H}x{W}."
                )
            m = m.clamp(0.0, 1.0)
            if m.shape[0] == 1 and B > 1:
                m = m.expand(B, -1, -1)  # broadcast single mask to all frames
            elif m.shape[0] != B:
                raise ValueError(
                    f"Video Inpaint: inpaint_mask has {m.shape[0]} frames but source_clip has {B}. "
                    "Must match or be 1 frame."
                )
            # (B,H,W) -> (B,H,W,3); contiguous so downstream ops see a real tensor.
            m3 = m.unsqueeze(-1).expand(-1, -1, -1, 3).contiguous()
            mask = m3
            # Composite grey over the masked regions as the VACE control signal.
            grey = torch.full_like(source_clip, GREY)
            control_frames = source_clip * (1.0 - m3) + grey * m3
            frames_to_generate = B
            return (mask, control_frames, source_clip, ph(), ph(), ph(), frames_to_generate)

        raise ValueError(f"Unknown mode: {mode}")
# Registration tables read by ComfyUI when this module is imported.
NODE_CLASS_MAPPINGS = {"VACEMaskGenerator": VACEMaskGenerator}

# Display titles shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {"VACEMaskGenerator": "VACE Mask Generator"}