diff --git a/README.md b/README.md index a65ba41..1f86167 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,11 @@ Restart ComfyUI. The node appears under the **VACE Tools** category. | Input | Type | Default | Description | |---|---|---|---| | `source_clip` | IMAGE | — | Source video frames (B, H, W, C tensor) | -| `mode` | ENUM | `End Extend` | Generation mode (see below). 8 modes available. | -| `target_frames` | INT | `81` | Total output frame count for mask and control_frames (1–10000). Unused by Frame Interpolation and Replace/Inpaint. | +| `mode` | ENUM | `End Extend` | Generation mode (see below). 9 modes available. | +| `target_frames` | INT | `81` | Total output frame count for mask and control_frames (1–10000). Unused by Frame Interpolation, Replace/Inpaint, and Video Inpaint. | | `split_index` | INT | `0` | Where to split the source. Meaning varies by mode. Unused by Edge/Join. Bidirectional: frames before clip (0 = even split). Frame Interpolation: new frames per gap. Replace/Inpaint: start index of replace region. | -| `edge_frames` | INT | `8` | Number of edge frames for Edge and Join modes. Replace/Inpaint: number of frames to replace. Unused by End/Pre/Middle/Bidirectional/Frame Interpolation. | +| `edge_frames` | INT | `8` | Number of edge frames for Edge and Join modes. Replace/Inpaint: number of frames to replace. Unused by End/Pre/Middle/Bidirectional/Frame Interpolation/Video Inpaint. | +| `inpaint_mask` | MASK | *(optional)* | Spatial inpaint mask for Video Inpaint mode (B, H, W). White (1.0) = regenerate, Black (0.0) = keep. Single frame broadcasts to all source frames. | ### Outputs @@ -206,6 +207,33 @@ control_frames: [ before frames ][ GREY × replace ][ after frames ] | `segment_3` | After — source[start+length:] | | `segment_4` | Placeholder | +--- + +### Video Inpaint + +Regenerate **spatial regions** within frames using a per-pixel mask. 
Unlike other modes that work at the frame level (entire frames kept or generated), Video Inpaint operates at the pixel level — masked regions are regenerated while the rest of each frame is preserved. + +- **`inpaint_mask`** *(required in this mode)* — a `MASK` (B, H, W) where white (1.0) marks regions to regenerate and black (0.0) marks regions to keep. Its height and width must match `source_clip` (a mismatch raises an error; the mask is not resized). A single-frame mask is automatically broadcast to all source frames; a multi-frame mask must have the same frame count as `source_clip`. +- **`target_frames`**, **`split_index`**, **`edge_frames`** — unused. +- **`frames_to_generate`** = `source_frames` (all frames are partially regenerated). +- **Total output** = `source_frames` (same length — in-place spatial replacement). + +Compositing formula per pixel: + +``` +control_frames = source × (1 − mask) + grey × mask +``` + +``` +mask: [ per-pixel mask broadcast to (B, H, W, 3) ] +control_frames: [ source pixels where mask=0, grey where mask=1 ] +``` + +| Segment | Content | +|---|---| +| `segment_1` | Full source clip | +| `segment_2`–`4` | Placeholder | + ## Dependencies None beyond PyTorch, which is bundled with ComfyUI. diff --git a/nodes.py b/nodes.py index cd3afb6..beb4b38 100644 --- a/nodes.py +++ b/nodes.py @@ -58,10 +58,11 @@ class VACEMaskGenerator: "Bidirectional Extend", "Frame Interpolation", "Replace/Inpaint", + "Video Inpaint", ], { "default": "End Extend", - "description": "End: generate after clip. Pre: generate before clip. Middle: generate at split point. Edge: generate between reversed edges (looping). Join: generate to heal two halves. Bidirectional: generate before AND after clip. Frame Interpolation: insert generated frames between each source pair. Replace/Inpaint: regenerate a range of frames in-place.", + "description": "End: generate after clip. Pre: generate before clip. Middle: generate at split point. Edge: generate between reversed edges (looping). Join: generate to heal two halves. Bidirectional: generate before AND after clip. 
Frame Interpolation: insert generated frames between each source pair. Replace/Inpaint: regenerate a range of frames in-place. Video Inpaint: regenerate masked spatial regions across all frames (requires inpaint_mask).", }, ), "target_frames": ( @@ -70,7 +71,7 @@ class VACEMaskGenerator: "default": 81, "min": 1, "max": 10000, - "description": "Total output frame count for mask and control_frames. Unused by Frame Interpolation and Replace/Inpaint.", + "description": "Total output frame count for mask and control_frames. Unused by Frame Interpolation, Replace/Inpaint, and Video Inpaint.", }, ), "split_index": ( @@ -91,10 +92,18 @@ class VACEMaskGenerator: "description": "Number of edge frames to use for Edge and Join modes. Unused by End/Pre/Middle. Replace/Inpaint: number of frames to replace.", }, ), - } + }, + "optional": { + "inpaint_mask": ( + "MASK", + { + "description": "Spatial inpaint mask for Video Inpaint mode. White (1.0) = regenerate, Black (0.0) = keep. Single frame broadcasts to all source frames.", + }, + ), + }, } - def generate(self, source_clip, mode, target_frames, split_index, edge_frames): + def generate(self, source_clip, mode, target_frames, split_index, edge_frames, inpaint_mask=None): B, H, W, C = source_clip.shape dev = source_clip.device BLACK = 0.0 @@ -199,6 +208,30 @@ class VACEMaskGenerator: control_frames = torch.cat([before, solid(length, GREY), after], dim=0) return (mask, control_frames, safe(before), safe(source_clip[start:end]), safe(after), ph(), frames_to_generate) + elif mode == "Video Inpaint": + if inpaint_mask is None: + raise ValueError("Video Inpaint mode requires the inpaint_mask input to be connected.") + m = inpaint_mask.to(dev) # (Bm, Hm, Wm) MASK type + if m.shape[1] != H or m.shape[2] != W: + raise ValueError( + f"Video Inpaint: inpaint_mask spatial size {m.shape[1]}x{m.shape[2]} " + f"doesn't match source_clip {H}x{W}." 
+ ) + m = m.clamp(0.0, 1.0) + if m.shape[0] == 1 and B > 1: + m = m.expand(B, -1, -1) # broadcast single mask to all frames + elif m.shape[0] != B: + raise ValueError( + f"Video Inpaint: inpaint_mask has {m.shape[0]} frames but source_clip has {B}. " + "Must match or be 1 frame." + ) + m3 = m.unsqueeze(-1).expand(-1, -1, -1, 3).contiguous() # (B,H,W) -> (B,H,W,3) + mask = m3 + grey = torch.full_like(source_clip, GREY) + control_frames = source_clip * (1.0 - m3) + grey * m3 + frames_to_generate = B + return (mask, control_frames, source_clip, ph(), ph(), ph(), frames_to_generate) + raise ValueError(f"Unknown mode: {mode}")