From 94d98186755df8bb013d2c0190001b3913bd4c70 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 13 Feb 2026 17:10:12 +0100 Subject: [PATCH] Fix FlashVSR quality: match naxci1 reference preprocessing - Remove front dummy frames (not used by reference implementation) - Use centered reflect padding instead of right/bottom replicate - Crop output from center matching padding offsets - Simplify temporal padding to 8k+1 alignment - Update progress bar estimation to match new formula Co-Authored-By: Claude Opus 4.6 --- inference.py | 106 +++++++++++++++++++++++++-------------------------- nodes.py | 11 +++--- 2 files changed, 57 insertions(+), 60 deletions(-) diff --git a/inference.py b/inference.py index 9093580..736ce36 100644 --- a/inference.py +++ b/inference.py @@ -727,73 +727,70 @@ class FlashVSRModel: return sw, sh, tw, th @staticmethod - def _align_frames(n): - """Round n to the form 8k+1 (pipeline requirement).""" - return 0 if n < 1 else ((n - 1) // 8) * 8 + 1 - - @staticmethod - def _restore_video_sequence(result, original_count): - """Strip 2 front warmup frames and trim to original count.""" - # Strip the first 2 frames (pipeline warmup from front dummy padding) - if result.shape[0] > 2: - result = result[2:] - # Trim or pad to exact expected count - if result.shape[0] > original_count: - result = result[:original_count] - elif result.shape[0] < original_count: - pad = result[-1:].expand(original_count - result.shape[0], *result.shape[1:]) + def _restore_video_sequence(result, expected): + """Trim pipeline output to the expected frame count.""" + if result.shape[0] > expected: + result = result[:expected] + elif result.shape[0] < expected: + pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:]) result = torch.cat([result, pad], dim=0) return result def _prepare_video(self, frames, scale): - """Convert [F, H, W, C] [0,1] frames to padded [1, C, F_aligned, H, W] [-1,1]. + """Convert [F, H, W, C] [0,1] frames to padded [1, C, F_padded, H, W] [-1,1]. - Matches the 1038lab reference preprocessing: - 1. Add 2 tail copies of last frame - 2. Add 2 front dummy copies of first frame + back padding to align_frames(N+4) - 3. Bicubic-upscale each frame to target resolution - 4. Normalize to [-1, 1] + Matches naxci1/ComfyUI-FlashVSR_Stable preprocessing: + 1. Bicubic-upscale each frame to target resolution + 2. Centered symmetric padding to 128-pixel alignment (reflect mode) + 3. Normalize to [-1, 1] + 4. Temporal padding: repeat last frame to reach 8k+1 count - The 2 front dummies are required by the pipeline's LQ conditioning - indexing. The pipeline outputs corresponding warmup frames that are - stripped by _restore_video_sequence. + No front dummy frames — the pipeline handles LQ indexing correctly + starting from frame 0. + + Returns: + video: [1, C, F_padded, H, W] tensor + th, tw: padded spatial dimensions + nf: padded frame count + sh, sw: actual (unpadded) spatial dimensions + pad_top, pad_left: spatial padding offsets for output cropping """ N, H, W, C = frames.shape sw, sh, tw, th = self._compute_dims(W, H, scale) - # Pad sequence: 2 tail copies + alignment (matches _pad_video_sequence) - padded_n = N + 2 - added_alignment = 0 - remainder = (padded_n - 5) % 8 - if remainder != 0: - added_alignment = 8 - remainder - padded_n += added_alignment - - # Total with 2 front dummies + back padding (matches prepare_video) - aligned = self._align_frames(padded_n + 4) + # Centered spatial padding offsets + pad_top = (th - sh) // 2 + pad_bottom = th - sh - pad_top + pad_left = (tw - sw) // 2 + pad_right = tw - sw - pad_left processed = [] - for i in range(aligned): - if i < 2: - idx = 0 # 2 front dummy frames (copy of first) - elif i > padded_n + 1: - idx = N - 1 # back padding (copy of last) - else: - src = i - 2 # shifted by 2 for the front dummies - idx = min(src, N - 1) # clamp to original range - - frame = frames[idx].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W] + for i in range(N): + frame = frames[i].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W] upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False) - pad_h, pad_w = th - sh, tw - sw - if pad_h > 0 or pad_w > 0: - upscaled = F.pad(upscaled, (0, pad_w, 0, pad_h), mode='replicate') + if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0: + # Centered reflect padding (matches naxci1 reference) + try: + upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='reflect') + except RuntimeError: + # Reflect requires pad < input size; fall back to replicate + upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate') normalized = upscaled * 2.0 - 1.0 processed.append(normalized.squeeze(0).cpu().to(self.dtype)) video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0) - nf = video.shape[2] # = aligned - return video, th, tw, nf, sh, sw, padded_n + # Temporal padding: repeat last frame to reach 8k+1 (pipeline requirement) + target = max(N, 25) # minimum 25 for streaming loop (P >= 1) + remainder = (target - 1) % 8 + if remainder != 0: + target += 8 - remainder + if target > N: + pad = video[:, :, -1:].repeat(1, 1, target - N, 1, 1) + video = torch.cat([video, pad], dim=2) + nf = video.shape[2] + + return video, th, tw, nf, sh, sw, pad_top, pad_left @staticmethod def _to_frames(video): @@ -836,8 +833,8 @@ class FlashVSRModel: original_count = frames.shape[0] - # Prepare video tensor (bicubic upscale + pad with 2 front dummies) - video, th, tw, nf, sh, sw, padded_n = self._prepare_video(frames, scale) + # Prepare video tensor (bicubic upscale + centered pad) + video, th, tw, nf, sh, sw, pad_top, pad_left = self._prepare_video(frames, scale) # Move LQ video to compute device (except for "long" mode which streams) if "long" not in self.pipe.__class__.__name__.lower(): @@ -857,10 +854,11 @@ class FlashVSRModel: color_fix=color_fix, unload_dit=unload_dit, ) - # Convert to ComfyUI format: crop to padded_n frames + spatial padding - result = self._to_frames(out).cpu()[:padded_n, :sh, :sw, :] + # Convert to ComfyUI format with centered spatial crop + result = self._to_frames(out).cpu() + result = result[:, pad_top:pad_top + sh, pad_left:pad_left + sw, :] - # Strip 2 front warmup frames and trim to original count + # Trim to original frame count result = self._restore_video_sequence(result, original_count) return result diff --git a/nodes.py b/nodes.py index 39d95f4..bc48012 100644 --- a/nodes.py +++ b/nodes.py @@ -1731,16 +1731,15 @@ class FlashVSRUpscale: chunks.append((prev_start, last_end)) # Estimate total pipeline steps for progress bar - # Mirrors _prepare_video: N+2 tail, align (n-5)%8, then align_frames(padded+4) + # Mirrors _prepare_video: target = max(N, 25), round up to 8k+1 total_steps = 0 for cs, ce in chunks: n = ce - cs - padded = n + 2 - remainder = (padded - 5) % 8 + target = max(n, 25) + remainder = (target - 1) % 8 if remainder != 0: - padded += 8 - remainder - aligned = FlashVSRModel._align_frames(padded + 4) - total_steps += max(1, (aligned - 1) // 8 - 2) + target += 8 - remainder + total_steps += max(1, (target - 1) // 8 - 2) pbar = ProgressBar(total_steps) step_ref = [0]