Snap target_frames output to 4n+1 (1, 5, 9, …, 81, …) for VACE encode

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-20 21:54:52 +01:00
parent b10fd0691b
commit 5cbe1f1d6a

View File

@@ -15,6 +15,11 @@ VACE_MODES = [
] ]
def _snap_4n1(n):
"""Round up to nearest 4n+1 value (1, 5, 9, 13, ..., 77, 81, ...)."""
return ((n + 2) // 4) * 4 + 1
def _create_solid_batch(count, height, width, color_value, device="cpu"): def _create_solid_batch(count, height, width, color_value, device="cpu"):
"""Create a batch of solid-color frames (B, H, W, 3). Returns empty tensor if count <= 0.""" """Create a batch of solid-color frames (B, H, W, 3). Returns empty tensor if count <= 0."""
if count <= 0: if count <= 0:
@@ -30,7 +35,7 @@ class VACEMaskGenerator:
OUTPUT_TOOLTIPS = ( OUTPUT_TOOLTIPS = (
"Visual reference for VACE — source pixels where mask is black, grey (#7f7f7f) fill where mask is white.", "Visual reference for VACE — source pixels where mask is black, grey (#7f7f7f) fill where mask is white.",
"Mask sequence — black (0) = keep original, white (1) = generate. Per-frame for most modes; per-pixel for Video Inpaint.", "Mask sequence — black (0) = keep original, white (1) = generate. Per-frame for most modes; per-pixel for Video Inpaint.",
"Total frame count of the output sequence — wire directly to VACE encode.", "Total frame count snapped to 4n+1 (1, 5, 9, …, 81, …) — wire directly to VACE encode.",
) )
DESCRIPTION = """VACE Mask Generator — builds mask + control_frames sequences for all VACE generation modes. DESCRIPTION = """VACE Mask Generator — builds mask + control_frames sequences for all VACE generation modes.
@@ -121,6 +126,7 @@ If your source is longer, use VACE Source Prep upstream to trim it first."""
def generate(self, source_clip, mode, target_frames, split_index, edge_frames, inpaint_mask=None, keyframe_positions=None): def generate(self, source_clip, mode, target_frames, split_index, edge_frames, inpaint_mask=None, keyframe_positions=None):
B, H, W, C = source_clip.shape B, H, W, C = source_clip.shape
dev = source_clip.device dev = source_clip.device
target_frames = _snap_4n1(target_frames)
modes_using_target = {"End Extend", "Pre Extend", "Middle Extend", "Edge Extend", modes_using_target = {"End Extend", "Pre Extend", "Middle Extend", "Edge Extend",
"Join Extend", "Bidirectional Extend", "Keyframe"} "Join Extend", "Bidirectional Extend", "Keyframe"}
@@ -215,7 +221,7 @@ If your source is longer, use VACE Source Prep upstream to trim it first."""
ctrl_parts.append(solid(step, GREY)) ctrl_parts.append(solid(step, GREY))
mask = torch.cat(mask_parts, dim=0) mask = torch.cat(mask_parts, dim=0)
control_frames = torch.cat(ctrl_parts, dim=0) control_frames = torch.cat(ctrl_parts, dim=0)
return (control_frames, mask, B + frames_to_generate) return (control_frames, mask, _snap_4n1(B + frames_to_generate))
elif mode == "Replace/Inpaint": elif mode == "Replace/Inpaint":
if split_index >= B: if split_index >= B:
@@ -231,7 +237,7 @@ If your source is longer, use VACE Source Prep upstream to trim it first."""
after = source_clip[end:] after = source_clip[end:]
mask = torch.cat([solid(before.shape[0], BLACK), solid(length, WHITE), solid(after.shape[0], BLACK)], dim=0) mask = torch.cat([solid(before.shape[0], BLACK), solid(length, WHITE), solid(after.shape[0], BLACK)], dim=0)
control_frames = torch.cat([before, solid(length, GREY), after], dim=0) control_frames = torch.cat([before, solid(length, GREY), after], dim=0)
return (control_frames, mask, B) return (control_frames, mask, _snap_4n1(B))
elif mode == "Video Inpaint": elif mode == "Video Inpaint":
if inpaint_mask is None: if inpaint_mask is None:
@@ -254,7 +260,7 @@ If your source is longer, use VACE Source Prep upstream to trim it first."""
mask = m3 mask = m3
grey = torch.full_like(source_clip, GREY) grey = torch.full_like(source_clip, GREY)
control_frames = source_clip * (1.0 - m3) + grey * m3 control_frames = source_clip * (1.0 - m3) + grey * m3
return (control_frames, mask, B) return (control_frames, mask, _snap_4n1(B))
elif mode == "Keyframe": elif mode == "Keyframe":
if B > target_frames: if B > target_frames: