diff --git a/inference.py b/inference.py
index a4809a2..6621eee 100644
--- a/inference.py
+++ b/inference.py
@@ -728,30 +728,28 @@ class FlashVSRModel:
 
     @staticmethod
     def _pad_video_5d(video):
-        """Pad [1, C, F, H, W] video: repeat last 2 frames, align for pipeline.
+        """Pad [1, C, F, H, W] video for pipeline (num_frames % 8 == 1, min 25).
 
-        Uses the reference formula: (F_padded + 2 - 5) % 8 == 0, ensuring
-        the pipeline's streaming loop gets correct iteration counts.
+        The pipeline requires num_frames % 4 == 1 and the streaming loop
+        produces 8*P+17 output frames where P = (num_frames-1)//8 - 2.
+        Using % 8 == 1 ensures output == num_frames (no wasted iterations).
+
+        Optimal input frame counts (zero padding): 25, 33, 41, 49, 57, 65, ...
         """
-        tail = video[:, :, -1:].repeat(1, 1, 2, 1, 1)
-        video = torch.cat([video, tail], dim=2)
-        added = 0
-        remainder = (video.shape[2] + 2 - 5) % 8
+        n = video.shape[2]
+        target = max(n, 25)  # minimum for streaming loop (P >= 1)
+        remainder = (target - 1) % 8
         if remainder != 0:
-            added = 8 - remainder
+            target += 8 - remainder
+        added = target - n
+        if added > 0:
             pad = video[:, :, -1:].repeat(1, 1, added, 1, 1)
             video = torch.cat([video, pad], dim=2)
         return video, added
 
     @staticmethod
     def _restore_video_sequence(result, added_frames, expected):
-        """Strip padding and warmup frames from the output."""
-        if added_frames > 0 and result.shape[0] > added_frames:
-            result = result[:-added_frames]
-        # Strip the first 2 pipeline warmup frames
-        if result.shape[0] > 2:
-            result = result[2:]
-        # Adjust to exact expected count
+        """Trim pipeline output to the expected frame count."""
         if result.shape[0] > expected:
             result = result[:expected]
         elif result.shape[0] < expected:
diff --git a/nodes.py b/nodes.py
index 2238b3d..c1c2120 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1731,14 +1731,14 @@ class FlashVSRUpscale:
             chunks.append((prev_start, last_end))
 
         # Estimate total pipeline steps for progress bar
-        # Mirrors _pad_video_5d: add 2 tail frames, then align with (F+2-5)%8
+        # Mirrors _pad_video_5d: pad to num_frames % 8 == 1, min 25
         total_steps = 0
         for cs, ce in chunks:
-            padded_n = (ce - cs) + 2  # tail frames appended by _pad_video_5d
-            remainder = (padded_n + 2 - 5) % 8
+            nf = max(ce - cs, 25)
+            remainder = (nf - 1) % 8
             if remainder != 0:
-                padded_n += 8 - remainder
-            total_steps += max(1, (padded_n - 1) // 8 - 2)
+                nf += 8 - remainder
+            total_steps += max(1, (nf - 1) // 8 - 2)
 
         pbar = ProgressBar(total_steps)
         step_ref = [0]