Fix FlashVSR ghosting: restore 2 front dummy frames matching reference
The pipeline's LQ conditioning indexing expects 2 front dummy frames (copies of the first frame) as warmup. A previous refactoring removed these, shifting all LQ conditioning by 2 frames and causing severe ghosting artifacts. This change matches the 1038lab reference preprocessing exactly:

1. _prepare_video: 2 tail copies + alignment + 2 front dummies + back padding
2. _restore_video_sequence: strip the first 2 warmup frames, then trim to the original frame count
3. Crop the pipeline output to padded_n frames before restoration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
98
inference.py
98
inference.py
@@ -727,55 +727,62 @@ class FlashVSRModel:
|
||||
return sw, sh, tw, th
|
||||
|
||||
@staticmethod
|
||||
def _pad_video_5d(video):
|
||||
"""Pad [1, C, F, H, W] video for pipeline (num_frames % 8 == 1, min 25).
|
||||
|
||||
The pipeline requires num_frames % 4 == 1 and the streaming loop
|
||||
produces 8*P+17 output frames where P = (num_frames-1)//8 - 2.
|
||||
Using % 8 == 1 ensures output == num_frames (no wasted iterations).
|
||||
|
||||
Optimal input frame counts (zero padding): 25, 33, 41, 49, 57, 65, ...
|
||||
"""
|
||||
n = video.shape[2]
|
||||
target = max(n, 25) # minimum for streaming loop (P >= 1)
|
||||
remainder = (target - 1) % 8
|
||||
if remainder != 0:
|
||||
target += 8 - remainder
|
||||
added = target - n
|
||||
if added > 0:
|
||||
pad = video[:, :, -1:].repeat(1, 1, added, 1, 1)
|
||||
video = torch.cat([video, pad], dim=2)
|
||||
return video, added
|
||||
def _align_frames(n):
|
||||
"""Round n to the form 8k+1 (pipeline requirement)."""
|
||||
return 0 if n < 1 else ((n - 1) // 8) * 8 + 1
|
||||
|
||||
@staticmethod
|
||||
def _restore_video_sequence(result, added_frames, expected):
|
||||
"""Trim pipeline output to the expected frame count."""
|
||||
if result.shape[0] > expected:
|
||||
result = result[:expected]
|
||||
elif result.shape[0] < expected:
|
||||
pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:])
|
||||
def _restore_video_sequence(result, original_count):
|
||||
"""Strip 2 front warmup frames and trim to original count."""
|
||||
# Strip the first 2 frames (pipeline warmup from front dummy padding)
|
||||
if result.shape[0] > 2:
|
||||
result = result[2:]
|
||||
# Trim or pad to exact expected count
|
||||
if result.shape[0] > original_count:
|
||||
result = result[:original_count]
|
||||
elif result.shape[0] < original_count:
|
||||
pad = result[-1:].expand(original_count - result.shape[0], *result.shape[1:])
|
||||
result = torch.cat([result, pad], dim=0)
|
||||
return result
|
||||
|
||||
def _prepare_video(self, frames, scale):
|
||||
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F, H, W] [-1,1].
|
||||
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F_aligned, H, W] [-1,1].
|
||||
|
||||
Bicubic-upscales each frame to the target resolution, normalizes to
|
||||
[-1, 1], then applies temporal padding for the pipeline.
|
||||
Matches the 1038lab reference preprocessing:
|
||||
1. Add 2 tail copies of last frame
|
||||
2. Add 2 front dummy copies of first frame + back padding to align_frames(N+4)
|
||||
3. Bicubic-upscale each frame to target resolution
|
||||
4. Normalize to [-1, 1]
|
||||
|
||||
Returns:
|
||||
video: [1, C, F_padded, H, W] tensor
|
||||
th, tw: padded spatial dimensions
|
||||
nf: padded frame count (= video.shape[2])
|
||||
sh, sw: actual (unpadded) spatial dimensions
|
||||
added: number of alignment-padding frames added
|
||||
The 2 front dummies are required by the pipeline's LQ conditioning
|
||||
indexing. The pipeline outputs corresponding warmup frames that are
|
||||
stripped by _restore_video_sequence.
|
||||
"""
|
||||
N, H, W, C = frames.shape
|
||||
sw, sh, tw, th = self._compute_dims(W, H, scale)
|
||||
|
||||
# Pad sequence: 2 tail copies + alignment (matches _pad_video_sequence)
|
||||
padded_n = N + 2
|
||||
added_alignment = 0
|
||||
remainder = (padded_n - 5) % 8
|
||||
if remainder != 0:
|
||||
added_alignment = 8 - remainder
|
||||
padded_n += added_alignment
|
||||
|
||||
# Total with 2 front dummies + back padding (matches prepare_video)
|
||||
aligned = self._align_frames(padded_n + 4)
|
||||
|
||||
processed = []
|
||||
for i in range(N):
|
||||
frame = frames[i].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
|
||||
for i in range(aligned):
|
||||
if i < 2:
|
||||
idx = 0 # 2 front dummy frames (copy of first)
|
||||
elif i > padded_n + 1:
|
||||
idx = N - 1 # back padding (copy of last)
|
||||
else:
|
||||
src = i - 2 # shifted by 2 for the front dummies
|
||||
idx = min(src, N - 1) # clamp to original range
|
||||
|
||||
frame = frames[idx].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
|
||||
upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
|
||||
pad_h, pad_w = th - sh, tw - sw
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
@@ -784,12 +791,9 @@ class FlashVSRModel:
|
||||
processed.append(normalized.squeeze(0).cpu().to(self.dtype))
|
||||
|
||||
video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
|
||||
nf = video.shape[2] # = aligned
|
||||
|
||||
# Apply temporal padding (tail + alignment)
|
||||
video, added = self._pad_video_5d(video)
|
||||
nf = video.shape[2]
|
||||
|
||||
return video, th, tw, nf, sh, sw, added
|
||||
return video, th, tw, nf, sh, sw, padded_n
|
||||
|
||||
@staticmethod
|
||||
def _to_frames(video):
|
||||
@@ -832,8 +836,8 @@ class FlashVSRModel:
|
||||
|
||||
original_count = frames.shape[0]
|
||||
|
||||
# Prepare video tensor (bicubic upscale + pad)
|
||||
video, th, tw, nf, sh, sw, added_frames = self._prepare_video(frames, scale)
|
||||
# Prepare video tensor (bicubic upscale + pad with 2 front dummies)
|
||||
video, th, tw, nf, sh, sw, padded_n = self._prepare_video(frames, scale)
|
||||
|
||||
# Move LQ video to compute device (except for "long" mode which streams)
|
||||
if "long" not in self.pipe.__class__.__name__.lower():
|
||||
@@ -853,10 +857,10 @@ class FlashVSRModel:
|
||||
color_fix=color_fix, unload_dit=unload_dit,
|
||||
)
|
||||
|
||||
# Convert to ComfyUI format and crop spatial padding
|
||||
result = self._to_frames(out).cpu()[:, :sh, :sw, :]
|
||||
# Convert to ComfyUI format: crop to padded_n frames + spatial padding
|
||||
result = self._to_frames(out).cpu()[:padded_n, :sh, :sw, :]
|
||||
|
||||
# Restore original frame count (strip temporal padding + warmup)
|
||||
result = self._restore_video_sequence(result, added_frames, original_count)
|
||||
# Strip 2 front warmup frames and trim to original count
|
||||
result = self._restore_video_sequence(result, original_count)
|
||||
|
||||
return result
|
||||
|
||||
12
nodes.py
12
nodes.py
@@ -1731,14 +1731,16 @@ class FlashVSRUpscale:
|
||||
chunks.append((prev_start, last_end))
|
||||
|
||||
# Estimate total pipeline steps for progress bar
|
||||
# Mirrors _pad_video_5d: pad to num_frames % 8 == 1, min 25
|
||||
# Mirrors _prepare_video: N+2 tail, align (n-5)%8, then align_frames(padded+4)
|
||||
total_steps = 0
|
||||
for cs, ce in chunks:
|
||||
nf = max(ce - cs, 25)
|
||||
remainder = (nf - 1) % 8
|
||||
n = ce - cs
|
||||
padded = n + 2
|
||||
remainder = (padded - 5) % 8
|
||||
if remainder != 0:
|
||||
nf += 8 - remainder
|
||||
total_steps += max(1, (nf - 1) // 8 - 2)
|
||||
padded += 8 - remainder
|
||||
aligned = FlashVSRModel._align_frames(padded + 4)
|
||||
total_steps += max(1, (aligned - 1) // 8 - 2)
|
||||
|
||||
pbar = ProgressBar(total_steps)
|
||||
step_ref = [0]
|
||||
|
||||
Reference in New Issue
Block a user