Fix FlashVSR quality: match naxci1 reference preprocessing

- Remove front dummy frames (not used by reference implementation)
- Use centered reflect padding instead of right/bottom replicate
- Crop output from center matching padding offsets
- Simplify temporal padding to 8k+1 alignment
- Update progress bar estimation to match new formula

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 17:10:12 +01:00
parent ea84ffef7c
commit 94d9818675
2 changed files with 57 additions and 60 deletions

View File

@@ -727,73 +727,70 @@ class FlashVSRModel:
return sw, sh, tw, th return sw, sh, tw, th
@staticmethod @staticmethod
def _align_frames(n): def _restore_video_sequence(result, expected):
"""Round n to the form 8k+1 (pipeline requirement).""" """Trim pipeline output to the expected frame count."""
return 0 if n < 1 else ((n - 1) // 8) * 8 + 1 if result.shape[0] > expected:
result = result[:expected]
@staticmethod elif result.shape[0] < expected:
def _restore_video_sequence(result, original_count): pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:])
"""Strip 2 front warmup frames and trim to original count."""
# Strip the first 2 frames (pipeline warmup from front dummy padding)
if result.shape[0] > 2:
result = result[2:]
# Trim or pad to exact expected count
if result.shape[0] > original_count:
result = result[:original_count]
elif result.shape[0] < original_count:
pad = result[-1:].expand(original_count - result.shape[0], *result.shape[1:])
result = torch.cat([result, pad], dim=0) result = torch.cat([result, pad], dim=0)
return result return result
def _prepare_video(self, frames, scale):
    """Convert [F, H, W, C] frames in [0, 1] to a padded [1, C, F_padded, H, W] tensor in [-1, 1].

    Matches the naxci1/ComfyUI-FlashVSR_Stable reference preprocessing:
      1. Bicubic-upscale each frame to the target resolution.
      2. Centered spatial padding to the aligned size (reflect mode,
         falling back to replicate when the pad exceeds the input size).
      3. Normalize from [0, 1] to [-1, 1].
      4. Temporal padding: repeat the last frame until the count is 8k+1
         (and at least 25 — presumably the streaming loop's minimum; TODO
         confirm against the pipeline).

    No front dummy frames are inserted — the pipeline handles LQ
    conditioning indexing correctly starting from frame 0.

    Args:
        frames: [F, H, W, C] float tensor, values in [0, 1].
        scale: upscale factor forwarded to self._compute_dims.

    Returns:
        Tuple of:
            video: [1, C, F_padded, H, W] tensor in [-1, 1], on CPU, self.dtype
            th, tw: padded spatial dimensions
            nf: padded frame count
            sh, sw: actual (unpadded) upscaled spatial dimensions
            pad_top, pad_left: spatial padding offsets for output cropping
    """
    N, H, W, C = frames.shape
    sw, sh, tw, th = self._compute_dims(W, H, scale)

    # Centered spatial padding offsets: split the slack between both
    # sides, giving the extra pixel (if odd) to bottom/right.
    pad_top = (th - sh) // 2
    pad_bottom = th - sh - pad_top
    pad_left = (tw - sw) // 2
    pad_right = tw - sw - pad_left

    processed = []
    for i in range(N):
        frame = frames[i].permute(2, 0, 1).unsqueeze(0)  # [1, C, H, W]
        upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
        if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0:
            # Centered reflect padding (matches naxci1 reference).
            try:
                upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='reflect')
            except RuntimeError:
                # Reflect padding requires pad < input size; fall back to replicate.
                upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate')
        normalized = upscaled * 2.0 - 1.0  # [0, 1] -> [-1, 1]
        # Keep frames on CPU in model dtype to bound GPU memory during prep.
        processed.append(normalized.squeeze(0).cpu().to(self.dtype))

    # [F, C, H, W] -> [1, C, F, H, W]
    video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)

    # Temporal padding: repeat last frame so the count is 8k+1 (pipeline requirement).
    target = max(N, 25)  # minimum 25 for the streaming loop (P >= 1)
    remainder = (target - 1) % 8
    if remainder != 0:
        target += 8 - remainder
    if target > N:
        pad = video[:, :, -1:].repeat(1, 1, target - N, 1, 1)
        video = torch.cat([video, pad], dim=2)
    nf = video.shape[2]
    return video, th, tw, nf, sh, sw, pad_top, pad_left
@staticmethod @staticmethod
def _to_frames(video): def _to_frames(video):
@@ -836,8 +833,8 @@ class FlashVSRModel:
original_count = frames.shape[0] original_count = frames.shape[0]
# Prepare video tensor (bicubic upscale + pad with 2 front dummies) # Prepare video tensor (bicubic upscale + centered pad)
video, th, tw, nf, sh, sw, padded_n = self._prepare_video(frames, scale) video, th, tw, nf, sh, sw, pad_top, pad_left = self._prepare_video(frames, scale)
# Move LQ video to compute device (except for "long" mode which streams) # Move LQ video to compute device (except for "long" mode which streams)
if "long" not in self.pipe.__class__.__name__.lower(): if "long" not in self.pipe.__class__.__name__.lower():
@@ -857,10 +854,11 @@ class FlashVSRModel:
color_fix=color_fix, unload_dit=unload_dit, color_fix=color_fix, unload_dit=unload_dit,
) )
# Convert to ComfyUI format: crop to padded_n frames + spatial padding # Convert to ComfyUI format with centered spatial crop
result = self._to_frames(out).cpu()[:padded_n, :sh, :sw, :] result = self._to_frames(out).cpu()
result = result[:, pad_top:pad_top + sh, pad_left:pad_left + sw, :]
# Strip 2 front warmup frames and trim to original count # Trim to original frame count
result = self._restore_video_sequence(result, original_count) result = self._restore_video_sequence(result, original_count)
return result return result

View File

@@ -1731,16 +1731,15 @@ class FlashVSRUpscale:
chunks.append((prev_start, last_end)) chunks.append((prev_start, last_end))
# Estimate total pipeline steps for progress bar # Estimate total pipeline steps for progress bar
# Mirrors _prepare_video: N+2 tail, align (n-5)%8, then align_frames(padded+4) # Mirrors _prepare_video: target = max(N, 25), round up to 8k+1
total_steps = 0 total_steps = 0
for cs, ce in chunks: for cs, ce in chunks:
n = ce - cs n = ce - cs
padded = n + 2 target = max(n, 25)
remainder = (padded - 5) % 8 remainder = (target - 1) % 8
if remainder != 0: if remainder != 0:
padded += 8 - remainder target += 8 - remainder
aligned = FlashVSRModel._align_frames(padded + 4) total_steps += max(1, (target - 1) // 8 - 2)
total_steps += max(1, (aligned - 1) // 8 - 2)
pbar = ProgressBar(total_steps) pbar = ProgressBar(total_steps)
step_ref = [0] step_ref = [0]