Fix FlashVSR frame padding to match pipeline requirements
The pipeline requires num_frames % 4 == 1. Our old _pad_video_5d used a wrong formula that produced non-conforming counts (e.g. 33 input → 35 padded → pipeline rounds up to 37, wasting VRAM).

The new padding targets num_frames % 8 == 1 (which also satisfies % 4 == 1) with a minimum of 25 frames, which ensures the streaming loop output exactly matches num_frames with zero waste. Optimal input counts: 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 105.

Also removes the incorrect 2-frame warmup stripping from _restore_video_sequence — the pipeline output doesn't have warmup artifacts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
28
inference.py
28
inference.py
@@ -728,30 +728,28 @@ class FlashVSRModel:
|
||||
|
||||
@staticmethod
|
||||
def _pad_video_5d(video):
|
||||
"""Pad [1, C, F, H, W] video: repeat last 2 frames, align for pipeline.
|
||||
"""Pad [1, C, F, H, W] video for pipeline (num_frames % 8 == 1, min 25).
|
||||
|
||||
Uses the reference formula: (F_padded + 2 - 5) % 8 == 0, ensuring
|
||||
the pipeline's streaming loop gets correct iteration counts.
|
||||
The pipeline requires num_frames % 4 == 1 and the streaming loop
|
||||
produces 8*P+17 output frames where P = (num_frames-1)//8 - 2.
|
||||
Using % 8 == 1 ensures output == num_frames (no wasted iterations).
|
||||
|
||||
Optimal input frame counts (zero padding): 25, 33, 41, 49, 57, 65, ...
|
||||
"""
|
||||
tail = video[:, :, -1:].repeat(1, 1, 2, 1, 1)
|
||||
video = torch.cat([video, tail], dim=2)
|
||||
added = 0
|
||||
remainder = (video.shape[2] + 2 - 5) % 8
|
||||
n = video.shape[2]
|
||||
target = max(n, 25) # minimum for streaming loop (P >= 1)
|
||||
remainder = (target - 1) % 8
|
||||
if remainder != 0:
|
||||
added = 8 - remainder
|
||||
target += 8 - remainder
|
||||
added = target - n
|
||||
if added > 0:
|
||||
pad = video[:, :, -1:].repeat(1, 1, added, 1, 1)
|
||||
video = torch.cat([video, pad], dim=2)
|
||||
return video, added
|
||||
|
||||
@staticmethod
|
||||
def _restore_video_sequence(result, added_frames, expected):
|
||||
"""Strip padding and warmup frames from the output."""
|
||||
if added_frames > 0 and result.shape[0] > added_frames:
|
||||
result = result[:-added_frames]
|
||||
# Strip the first 2 pipeline warmup frames
|
||||
if result.shape[0] > 2:
|
||||
result = result[2:]
|
||||
# Adjust to exact expected count
|
||||
"""Trim pipeline output to the expected frame count."""
|
||||
if result.shape[0] > expected:
|
||||
result = result[:expected]
|
||||
elif result.shape[0] < expected:
|
||||
|
||||
10
nodes.py
10
nodes.py
@@ -1731,14 +1731,14 @@ class FlashVSRUpscale:
|
||||
chunks.append((prev_start, last_end))
|
||||
|
||||
# Estimate total pipeline steps for progress bar
|
||||
# Mirrors _pad_video_5d: add 2 tail frames, then align with (F+2-5)%8
|
||||
# Mirrors _pad_video_5d: pad to num_frames % 8 == 1, min 25
|
||||
total_steps = 0
|
||||
for cs, ce in chunks:
|
||||
padded_n = (ce - cs) + 2 # tail frames appended by _pad_video_5d
|
||||
remainder = (padded_n + 2 - 5) % 8
|
||||
nf = max(ce - cs, 25)
|
||||
remainder = (nf - 1) % 8
|
||||
if remainder != 0:
|
||||
padded_n += 8 - remainder
|
||||
total_steps += max(1, (padded_n - 1) // 8 - 2)
|
||||
nf += 8 - remainder
|
||||
total_steps += max(1, (nf - 1) // 8 - 2)
|
||||
|
||||
pbar = ProgressBar(total_steps)
|
||||
step_ref = [0]
|
||||
|
||||
Reference in New Issue
Block a user