diff --git a/inference.py b/inference.py index 6621eee..9093580 100644 --- a/inference.py +++ b/inference.py @@ -727,55 +727,62 @@ class FlashVSRModel: return sw, sh, tw, th @staticmethod - def _pad_video_5d(video): - """Pad [1, C, F, H, W] video for pipeline (num_frames % 8 == 1, min 25). - - The pipeline requires num_frames % 4 == 1 and the streaming loop - produces 8*P+17 output frames where P = (num_frames-1)//8 - 2. - Using % 8 == 1 ensures output == num_frames (no wasted iterations). - - Optimal input frame counts (zero padding): 25, 33, 41, 49, 57, 65, ... - """ - n = video.shape[2] - target = max(n, 25) # minimum for streaming loop (P >= 1) - remainder = (target - 1) % 8 - if remainder != 0: - target += 8 - remainder - added = target - n - if added > 0: - pad = video[:, :, -1:].repeat(1, 1, added, 1, 1) - video = torch.cat([video, pad], dim=2) - return video, added + def _align_frames(n): + """Round n down to the nearest value of the form 8k+1 (pipeline requirement); returns 0 when n < 1.""" + return 0 if n < 1 else ((n - 1) // 8) * 8 + 1 @staticmethod - def _restore_video_sequence(result, added_frames, expected): - """Trim pipeline output to the expected frame count.""" - if result.shape[0] > expected: - result = result[:expected] - elif result.shape[0] < expected: - pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:]) + def _restore_video_sequence(result, original_count): + """Strip 2 front warmup frames (only when more than 2 frames exist), then trim or pad by repeating the last frame to original_count.""" + # Strip the first 2 frames (pipeline warmup from front dummy padding) + if result.shape[0] > 2: + result = result[2:] + # Trim or pad to exact expected count + if result.shape[0] > original_count: + result = result[:original_count] + elif result.shape[0] < original_count: + pad = result[-1:].expand(original_count - result.shape[0], *result.shape[1:]) result = torch.cat([result, pad], dim=0) return result def _prepare_video(self, frames, scale): - """Convert [F, H, W, C] [0,1] frames to padded [1, C, F, H, W] [-1,1]. 
+ """Convert [F, H, W, C] [0,1] frames to padded [1, C, F_aligned, H, W] [-1,1]. - Bicubic-upscales each frame to the target resolution, normalizes to - [-1, 1], then applies temporal padding for the pipeline. + Matches the 1038lab reference preprocessing: + 1. Add 2 tail copies of last frame, then extend so that (padded_n - 5) % 8 == 0 + 2. Add 2 front dummy copies of first frame + back padding to _align_frames(padded_n + 4) + 3. Bicubic-upscale each frame to target resolution + 4. Normalize to [-1, 1] - Returns: - video: [1, C, F_padded, H, W] tensor - th, tw: padded spatial dimensions - nf: padded frame count (= video.shape[2]) - sh, sw: actual (unpadded) spatial dimensions - added: number of alignment-padding frames added + The 2 front dummies are required by the pipeline's LQ conditioning + indexing. The pipeline outputs corresponding warmup frames that are + stripped by _restore_video_sequence. Returns the tuple (video, th, tw, nf, sh, sw, padded_n), where nf = video.shape[2]. """ N, H, W, C = frames.shape sw, sh, tw, th = self._compute_dims(W, H, scale) + # Pad sequence: 2 tail copies + alignment (matches _pad_video_sequence) + padded_n = N + 2 + added_alignment = 0 + remainder = (padded_n - 5) % 8 + if remainder != 0: + added_alignment = 8 - remainder + padded_n += added_alignment + + # Total with 2 front dummies + back padding (matches prepare_video) + aligned = self._align_frames(padded_n + 4) + processed = [] - for i in range(N): - frame = frames[i].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W] + for i in range(aligned): + if i < 2: + idx = 0 # 2 front dummy frames (copy of first) + elif i > padded_n + 1: + idx = N - 1 # back padding (copy of last) + else: + src = i - 2 # shifted by 2 for the front dummies + idx = min(src, N - 1) # clamp to original range + + frame = frames[idx].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W] upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False) pad_h, pad_w = th - sh, tw - sw if pad_h > 0 or pad_w > 0: @@ -784,12 +791,9 @@ class FlashVSRModel: processed.append(normalized.squeeze(0).cpu().to(self.dtype)) video = 
torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0) + nf = video.shape[2] # = aligned - # Apply temporal padding (tail + alignment) - video, added = self._pad_video_5d(video) - nf = video.shape[2] - - return video, th, tw, nf, sh, sw, added + return video, th, tw, nf, sh, sw, padded_n @staticmethod def _to_frames(video): @@ -832,8 +836,8 @@ class FlashVSRModel: original_count = frames.shape[0] - # Prepare video tensor (bicubic upscale + pad) - video, th, tw, nf, sh, sw, added_frames = self._prepare_video(frames, scale) + # Prepare video tensor (bicubic upscale + pad with 2 front dummies) + video, th, tw, nf, sh, sw, padded_n = self._prepare_video(frames, scale) # Move LQ video to compute device (except for "long" mode which streams) if "long" not in self.pipe.__class__.__name__.lower(): @@ -853,10 +857,10 @@ class FlashVSRModel: color_fix=color_fix, unload_dit=unload_dit, ) - # Convert to ComfyUI format and crop spatial padding - result = self._to_frames(out).cpu()[:, :sh, :sw, :] + # Convert to ComfyUI format: keep the first padded_n frames and crop off the spatial padding + result = self._to_frames(out).cpu()[:padded_n, :sh, :sw, :] - # Restore original frame count (strip temporal padding + warmup) - result = self._restore_video_sequence(result, added_frames, original_count) + # Strip 2 front warmup frames, then trim or pad to original count + result = self._restore_video_sequence(result, original_count) return result diff --git a/nodes.py b/nodes.py index c1c2120..39d95f4 100644 --- a/nodes.py +++ b/nodes.py @@ -1731,14 +1731,16 @@ class FlashVSRUpscale: chunks.append((prev_start, last_end)) # Estimate total pipeline steps for progress bar - # Mirrors _pad_video_5d: pad to num_frames % 8 == 1, min 25 + # Mirrors _prepare_video: N+2 tail, align (n-5)%8, then _align_frames(padded+4) total_steps = 0 for cs, ce in chunks: - nf = max(ce - cs, 25) - remainder = (nf - 1) % 8 + n = ce - cs + padded = n + 2 + remainder = (padded - 5) % 8 if remainder != 0: - nf += 8 - remainder - total_steps += 
max(1, (nf - 1) // 8 - 2) + padded += 8 - remainder + aligned = FlashVSRModel._align_frames(padded + 4) + total_steps += max(1, (aligned - 1) // 8 - 2) pbar = ProgressBar(total_steps) step_ref = [0]