Fix FlashVSR ghosting: restore 2 front dummy frames matching reference

The pipeline's LQ conditioning indexing expects 2 front dummy frames
(copies of first frame) as warmup. Our previous refactoring removed
these, shifting all LQ conditioning by 2 frames and causing severe
ghosting artifacts.

Now matches the 1038lab reference preprocessing exactly:
1. _prepare_video: 2 tail copies + alignment + 2 front dummies + back padding
2. _restore_video_sequence: strip first 2 warmup frames + trim to original count
3. Crop pipeline output to padded_n before restoration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 16:49:46 +01:00
parent 4cc6e9c705
commit ea84ffef7c
2 changed files with 58 additions and 52 deletions

View File

@@ -727,55 +727,62 @@ class FlashVSRModel:
return sw, sh, tw, th
@staticmethod
def _pad_video_5d(video):
"""Pad [1, C, F, H, W] video for pipeline (num_frames % 8 == 1, min 25).
The pipeline requires num_frames % 4 == 1 and the streaming loop
produces 8*P+17 output frames where P = (num_frames-1)//8 - 2.
Using % 8 == 1 ensures output == num_frames (no wasted iterations).
Optimal input frame counts (zero padding): 25, 33, 41, 49, 57, 65, ...
"""
n = video.shape[2]
target = max(n, 25) # minimum for streaming loop (P >= 1)
remainder = (target - 1) % 8
if remainder != 0:
target += 8 - remainder
added = target - n
if added > 0:
pad = video[:, :, -1:].repeat(1, 1, added, 1, 1)
video = torch.cat([video, pad], dim=2)
return video, added
def _align_frames(n):
"""Round n down to the nearest value of the form 8k+1 (pipeline requirement); returns 0 if n < 1."""
return 0 if n < 1 else ((n - 1) // 8) * 8 + 1
@staticmethod
def _restore_video_sequence(result, added_frames, expected):
"""Trim pipeline output to the expected frame count."""
if result.shape[0] > expected:
result = result[:expected]
elif result.shape[0] < expected:
pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:])
def _restore_video_sequence(result, original_count):
"""Strip 2 front warmup frames and trim to original count."""
# Strip the first 2 frames (pipeline warmup from front dummy padding)
if result.shape[0] > 2:
result = result[2:]
# Trim or pad to exact expected count
if result.shape[0] > original_count:
result = result[:original_count]
elif result.shape[0] < original_count:
pad = result[-1:].expand(original_count - result.shape[0], *result.shape[1:])
result = torch.cat([result, pad], dim=0)
return result
def _prepare_video(self, frames, scale):
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F, H, W] [-1,1].
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F_aligned, H, W] [-1,1].
Bicubic-upscales each frame to the target resolution, normalizes to
[-1, 1], then applies temporal padding for the pipeline.
Matches the 1038lab reference preprocessing:
1. Add 2 tail copies of last frame, then pad so that (padded_n - 5) % 8 == 0
2. Add 2 front dummy copies of first frame + back padding to _align_frames(padded_n + 4)
3. Bicubic-upscale each frame to target resolution
4. Normalize to [-1, 1]
Returns:
video: [1, C, F_padded, H, W] tensor
th, tw: padded spatial dimensions
nf: padded frame count (= video.shape[2])
sh, sw: actual (unpadded) spatial dimensions
added: number of alignment-padding frames added
The 2 front dummies are required by the pipeline's LQ conditioning
indexing. The pipeline outputs corresponding warmup frames that are
stripped by _restore_video_sequence.
"""
N, H, W, C = frames.shape
sw, sh, tw, th = self._compute_dims(W, H, scale)
# Pad sequence: 2 tail copies + alignment (matches _pad_video_sequence)
padded_n = N + 2
added_alignment = 0
remainder = (padded_n - 5) % 8
if remainder != 0:
added_alignment = 8 - remainder
padded_n += added_alignment
# Total with 2 front dummies + back padding (matches prepare_video)
aligned = self._align_frames(padded_n + 4)
processed = []
for i in range(N):
frame = frames[i].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
for i in range(aligned):
if i < 2:
idx = 0 # 2 front dummy frames (copy of first)
elif i > padded_n + 1:
idx = N - 1 # back padding (copy of last)
else:
src = i - 2 # shifted by 2 for the front dummies
idx = min(src, N - 1) # clamp to original range
frame = frames[idx].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
pad_h, pad_w = th - sh, tw - sw
if pad_h > 0 or pad_w > 0:
@@ -784,12 +791,9 @@ class FlashVSRModel:
processed.append(normalized.squeeze(0).cpu().to(self.dtype))
video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
nf = video.shape[2] # = aligned
# Apply temporal padding (tail + alignment)
video, added = self._pad_video_5d(video)
nf = video.shape[2]
return video, th, tw, nf, sh, sw, added
return video, th, tw, nf, sh, sw, padded_n
@staticmethod
def _to_frames(video):
@@ -832,8 +836,8 @@ class FlashVSRModel:
original_count = frames.shape[0]
# Prepare video tensor (bicubic upscale + pad)
video, th, tw, nf, sh, sw, added_frames = self._prepare_video(frames, scale)
# Prepare video tensor (bicubic upscale + pad with 2 front dummies)
video, th, tw, nf, sh, sw, padded_n = self._prepare_video(frames, scale)
# Move LQ video to compute device (except for "long" mode which streams)
if "long" not in self.pipe.__class__.__name__.lower():
@@ -853,10 +857,10 @@ class FlashVSRModel:
color_fix=color_fix, unload_dit=unload_dit,
)
# Convert to ComfyUI format and crop spatial padding
result = self._to_frames(out).cpu()[:, :sh, :sw, :]
# Convert to ComfyUI format: keep the first padded_n frames and crop off the spatial padding
result = self._to_frames(out).cpu()[:padded_n, :sh, :sw, :]
# Restore original frame count (strip temporal padding + warmup)
result = self._restore_video_sequence(result, added_frames, original_count)
# Strip 2 front warmup frames and trim to original count
result = self._restore_video_sequence(result, original_count)
return result

View File

@@ -1731,14 +1731,16 @@ class FlashVSRUpscale:
chunks.append((prev_start, last_end))
# Estimate total pipeline steps for progress bar
# Mirrors _pad_video_5d: pad to num_frames % 8 == 1, min 25
# Mirrors _prepare_video: N+2 tail, align (n-5)%8, then align_frames(padded+4)
total_steps = 0
for cs, ce in chunks:
nf = max(ce - cs, 25)
remainder = (nf - 1) % 8
n = ce - cs
padded = n + 2
remainder = (padded - 5) % 8
if remainder != 0:
nf += 8 - remainder
total_steps += max(1, (nf - 1) // 8 - 2)
padded += 8 - remainder
aligned = FlashVSRModel._align_frames(padded + 4)
total_steps += max(1, (aligned - 1) // 8 - 2)
pbar = ProgressBar(total_steps)
step_ref = [0]