Fix FlashVSR frame padding to match pipeline requirements
The pipeline requires num_frames % 4 == 1. The old _pad_video_5d used a formula that produced non-conforming frame counts (e.g. 33 input frames → 35 padded → the pipeline rounds up to 37, wasting VRAM).

The new padding targets num_frames % 8 == 1 with a minimum of 25 frames. Any count satisfying % 8 == 1 also satisfies % 4 == 1, and because the streaming loop produces 8*P + 17 output frames with P = (num_frames - 1) // 8 - 2, the output then matches num_frames exactly, with no wasted iterations.

Optimal input frame counts (no padding needed): 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 105.

Also removes the incorrect 2-frame warmup stripping from _restore_video_sequence — the pipeline output does not contain warmup artifacts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
28
inference.py
28
inference.py
@@ -728,30 +728,28 @@ class FlashVSRModel:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _pad_video_5d(video):
|
def _pad_video_5d(video):
|
||||||
"""Pad [1, C, F, H, W] video: repeat last 2 frames, align for pipeline.
|
"""Pad [1, C, F, H, W] video for pipeline (num_frames % 8 == 1, min 25).
|
||||||
|
|
||||||
Uses the reference formula: (F_padded + 2 - 5) % 8 == 0, ensuring
|
The pipeline requires num_frames % 4 == 1 and the streaming loop
|
||||||
the pipeline's streaming loop gets correct iteration counts.
|
produces 8*P+17 output frames where P = (num_frames-1)//8 - 2.
|
||||||
|
Using % 8 == 1 ensures output == num_frames (no wasted iterations).
|
||||||
|
|
||||||
|
Optimal input frame counts (zero padding): 25, 33, 41, 49, 57, 65, ...
|
||||||
"""
|
"""
|
||||||
tail = video[:, :, -1:].repeat(1, 1, 2, 1, 1)
|
n = video.shape[2]
|
||||||
video = torch.cat([video, tail], dim=2)
|
target = max(n, 25) # minimum for streaming loop (P >= 1)
|
||||||
added = 0
|
remainder = (target - 1) % 8
|
||||||
remainder = (video.shape[2] + 2 - 5) % 8
|
|
||||||
if remainder != 0:
|
if remainder != 0:
|
||||||
added = 8 - remainder
|
target += 8 - remainder
|
||||||
|
added = target - n
|
||||||
|
if added > 0:
|
||||||
pad = video[:, :, -1:].repeat(1, 1, added, 1, 1)
|
pad = video[:, :, -1:].repeat(1, 1, added, 1, 1)
|
||||||
video = torch.cat([video, pad], dim=2)
|
video = torch.cat([video, pad], dim=2)
|
||||||
return video, added
|
return video, added
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _restore_video_sequence(result, added_frames, expected):
|
def _restore_video_sequence(result, added_frames, expected):
|
||||||
"""Strip padding and warmup frames from the output."""
|
"""Trim pipeline output to the expected frame count."""
|
||||||
if added_frames > 0 and result.shape[0] > added_frames:
|
|
||||||
result = result[:-added_frames]
|
|
||||||
# Strip the first 2 pipeline warmup frames
|
|
||||||
if result.shape[0] > 2:
|
|
||||||
result = result[2:]
|
|
||||||
# Adjust to exact expected count
|
|
||||||
if result.shape[0] > expected:
|
if result.shape[0] > expected:
|
||||||
result = result[:expected]
|
result = result[:expected]
|
||||||
elif result.shape[0] < expected:
|
elif result.shape[0] < expected:
|
||||||
|
|||||||
10
nodes.py
10
nodes.py
@@ -1731,14 +1731,14 @@ class FlashVSRUpscale:
|
|||||||
chunks.append((prev_start, last_end))
|
chunks.append((prev_start, last_end))
|
||||||
|
|
||||||
# Estimate total pipeline steps for progress bar
|
# Estimate total pipeline steps for progress bar
|
||||||
# Mirrors _pad_video_5d: add 2 tail frames, then align with (F+2-5)%8
|
# Mirrors _pad_video_5d: pad to num_frames % 8 == 1, min 25
|
||||||
total_steps = 0
|
total_steps = 0
|
||||||
for cs, ce in chunks:
|
for cs, ce in chunks:
|
||||||
padded_n = (ce - cs) + 2 # tail frames appended by _pad_video_5d
|
nf = max(ce - cs, 25)
|
||||||
remainder = (padded_n + 2 - 5) % 8
|
remainder = (nf - 1) % 8
|
||||||
if remainder != 0:
|
if remainder != 0:
|
||||||
padded_n += 8 - remainder
|
nf += 8 - remainder
|
||||||
total_steps += max(1, (padded_n - 1) // 8 - 2)
|
total_steps += max(1, (nf - 1) // 8 - 2)
|
||||||
|
|
||||||
pbar = ProgressBar(total_steps)
|
pbar = ProgressBar(total_steps)
|
||||||
step_ref = [0]
|
step_ref = [0]
|
||||||
|
|||||||
Reference in New Issue
Block a user