Fix FlashVSR frame padding to match pipeline requirements

The pipeline requires num_frames % 4 == 1. Our old _pad_video_5d appended 2
tail frames and then aligned with (F + 2 - 5) % 8 == 0, which always yields
num_frames % 8 == 3 and therefore never conforms (e.g. 33 input → 35
padded → pipeline rounds to 37, wasting VRAM).

New padding targets num_frames % 8 == 1 (which also satisfies % 4 == 1) with
a minimum of 25 frames, which ensures the streaming loop output exactly
matches num_frames with zero waste. Optimal input counts (no padding
needed): 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 105.

Also removes incorrect 2-frame warmup stripping from _restore_video_sequence
— the pipeline output doesn't have warmup artifacts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 16:20:02 +01:00
parent 5071c4de4f
commit 11e2acb9e0
2 changed files with 18 additions and 20 deletions

View File

@@ -728,30 +728,28 @@ class FlashVSRModel:
@staticmethod
def _pad_video_5d(video):
"""Pad [1, C, F, H, W] video: repeat last 2 frames, align for pipeline.
"""Pad [1, C, F, H, W] video for pipeline (num_frames % 8 == 1, min 25).
Uses the reference formula: (F_padded + 2 - 5) % 8 == 0, ensuring
the pipeline's streaming loop gets correct iteration counts.
The pipeline requires num_frames % 4 == 1 and the streaming loop
produces 8*P+17 output frames where P = (num_frames-1)//8 - 2.
Using % 8 == 1 ensures output == num_frames (no wasted iterations).
Optimal input frame counts (zero padding): 25, 33, 41, 49, 57, 65, ...
"""
tail = video[:, :, -1:].repeat(1, 1, 2, 1, 1)
video = torch.cat([video, tail], dim=2)
added = 0
remainder = (video.shape[2] + 2 - 5) % 8
n = video.shape[2]
target = max(n, 25) # minimum for streaming loop (P >= 1)
remainder = (target - 1) % 8
if remainder != 0:
added = 8 - remainder
target += 8 - remainder
added = target - n
if added > 0:
pad = video[:, :, -1:].repeat(1, 1, added, 1, 1)
video = torch.cat([video, pad], dim=2)
return video, added
@staticmethod
def _restore_video_sequence(result, added_frames, expected):
"""Strip padding and warmup frames from the output."""
if added_frames > 0 and result.shape[0] > added_frames:
result = result[:-added_frames]
# Strip the first 2 pipeline warmup frames
if result.shape[0] > 2:
result = result[2:]
# Adjust to exact expected count
"""Trim pipeline output to the expected frame count."""
if result.shape[0] > expected:
result = result[:expected]
elif result.shape[0] < expected:

View File

@@ -1731,14 +1731,14 @@ class FlashVSRUpscale:
chunks.append((prev_start, last_end))
# Estimate total pipeline steps for progress bar
# Mirrors _pad_video_5d: add 2 tail frames, then align with (F+2-5)%8
# Mirrors _pad_video_5d: pad to num_frames % 8 == 1, min 25
total_steps = 0
for cs, ce in chunks:
padded_n = (ce - cs) + 2 # tail frames appended by _pad_video_5d
remainder = (padded_n + 2 - 5) % 8
nf = max(ce - cs, 25)
remainder = (nf - 1) % 8
if remainder != 0:
padded_n += 8 - remainder
total_steps += max(1, (padded_n - 1) // 8 - 2)
nf += 8 - remainder
total_steps += max(1, (nf - 1) // 8 - 2)
pbar = ProgressBar(total_steps)
step_ref = [0]