Add Video Inpaint mode for per-pixel spatial mask regeneration

New 9th mode that works at the pixel level rather than the frame level.
Accepts an optional MASK input (B,H,W) to mark spatial regions for
regeneration, with single-frame broadcast, spatial dimension validation,
and contiguous output tensors.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-19 13:37:04 +01:00
parent d8143ea889
commit 87ec7b3938
2 changed files with 68 additions and 7 deletions

View File

@@ -58,10 +58,11 @@ class VACEMaskGenerator:
"Bidirectional Extend",
"Frame Interpolation",
"Replace/Inpaint",
"Video Inpaint",
],
{
"default": "End Extend",
"description": "End: generate after clip. Pre: generate before clip. Middle: generate at split point. Edge: generate between reversed edges (looping). Join: generate to heal two halves. Bidirectional: generate before AND after clip. Frame Interpolation: insert generated frames between each source pair. Replace/Inpaint: regenerate a range of frames in-place.",
"description": "End: generate after clip. Pre: generate before clip. Middle: generate at split point. Edge: generate between reversed edges (looping). Join: generate to heal two halves. Bidirectional: generate before AND after clip. Frame Interpolation: insert generated frames between each source pair. Replace/Inpaint: regenerate a range of frames in-place. Video Inpaint: regenerate masked spatial regions across all frames (requires inpaint_mask).",
},
),
"target_frames": (
@@ -70,7 +71,7 @@ class VACEMaskGenerator:
"default": 81,
"min": 1,
"max": 10000,
"description": "Total output frame count for mask and control_frames. Unused by Frame Interpolation and Replace/Inpaint.",
"description": "Total output frame count for mask and control_frames. Unused by Frame Interpolation, Replace/Inpaint, and Video Inpaint.",
},
),
"split_index": (
@@ -91,10 +92,18 @@ class VACEMaskGenerator:
"description": "Number of edge frames to use for Edge and Join modes. Unused by End/Pre/Middle. Replace/Inpaint: number of frames to replace.",
},
),
}
},
"optional": {
"inpaint_mask": (
"MASK",
{
"description": "Spatial inpaint mask for Video Inpaint mode. White (1.0) = regenerate, Black (0.0) = keep. Single frame broadcasts to all source frames.",
},
),
},
}
def generate(self, source_clip, mode, target_frames, split_index, edge_frames):
def generate(self, source_clip, mode, target_frames, split_index, edge_frames, inpaint_mask=None):
B, H, W, C = source_clip.shape
dev = source_clip.device
BLACK = 0.0
@@ -199,6 +208,30 @@ class VACEMaskGenerator:
control_frames = torch.cat([before, solid(length, GREY), after], dim=0)
return (mask, control_frames, safe(before), safe(source_clip[start:end]), safe(after), ph(), frames_to_generate)
elif mode == "Video Inpaint":
if inpaint_mask is None:
raise ValueError("Video Inpaint mode requires the inpaint_mask input to be connected.")
m = inpaint_mask.to(dev) # (Bm, Hm, Wm) MASK type
if m.shape[1] != H or m.shape[2] != W:
raise ValueError(
f"Video Inpaint: inpaint_mask spatial size {m.shape[1]}x{m.shape[2]} "
f"doesn't match source_clip {H}x{W}."
)
m = m.clamp(0.0, 1.0)
if m.shape[0] == 1 and B > 1:
m = m.expand(B, -1, -1) # broadcast single mask to all frames
elif m.shape[0] != B:
raise ValueError(
f"Video Inpaint: inpaint_mask has {m.shape[0]} frames but source_clip has {B}. "
"Must match or be 1 frame."
)
m3 = m.unsqueeze(-1).expand(-1, -1, -1, 3).contiguous() # (B,H,W) -> (B,H,W,3)
mask = m3
grey = torch.full_like(source_clip, GREY)
control_frames = source_clip * (1.0 - m3) + grey * m3
frames_to_generate = B
return (mask, control_frames, source_clip, ph(), ph(), ph(), frames_to_generate)
raise ValueError(f"Unknown mode: {mode}")