Fix FlashVSR quality: two-stage temporal padding, kv_ratio=3, float64 precision

Root cause of remaining ghosting: our single-stage temporal padding
(N+4 → floor to 8k+1) TRUNCATED frames when N+4 wasn't already 8k+1.
For 50 frames: 50+4=54 → floor to 49, LOSING the last input frame.
The pipeline then processed misaligned LQ→output frame mapping.

Fix matches naxci1/ComfyUI-FlashVSR_Stable two-stage approach:
1. Pad to next_8n5(N) (next integer >= N of form 8k+5, minimum 21)
2. Add 4 → result is always 8(k+1)+1, a valid 8k+1 — NEVER truncates

Also:
- kv_ratio default 2.0→3.0 (matches naxci1, max quality KV cache)
- local_range default 9→11 (more stable temporal consistency)
- sinusoidal_embedding_1d, precompute_freqs_cis, rope_apply: float32→float64
  (matches naxci1 reference precision for embeddings and RoPE)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 18:06:46 +01:00
parent fa250897a2
commit 76dff7e573
3 changed files with 36 additions and 34 deletions

View File

@@ -740,17 +740,22 @@ class FlashVSRModel:
result = torch.cat([result, pad], dim=0)
return result
@staticmethod
def _next_8n5(n, minimum=21):
"""Next integer >= n of the form 8k+5 (minimum 21)."""
if n < minimum:
return minimum
return ((n - 5 + 7) // 8) * 8 + 5
def _prepare_video(self, frames, scale):
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F_padded, H, W] [-1,1].
Matches naxci1/ComfyUI-FlashVSR_Stable preprocessing:
Matches naxci1/ComfyUI-FlashVSR_Stable two-stage temporal padding:
1. Bicubic-upscale each frame to target resolution
2. Centered symmetric padding to 128-pixel alignment (reflect mode)
3. Normalize to [-1, 1]
4. Temporal padding: N+4 then floor to largest 8k+1 (matches naxci1 reference)
No front dummy frames — the pipeline handles LQ indexing correctly
starting from frame 0.
4. Stage 1: Pad frame count to next 8n+5 (min 21) by repeating last frame
5. Stage 2: Add 4 → result is always 8k+1 (since 8n+5+4 = 8(n+1)+1)
Returns:
video: [1, C, F_padded, H, W] tensor
@@ -762,6 +767,12 @@ class FlashVSRModel:
N, H, W, C = frames.shape
sw, sh, tw, th = self._compute_dims(W, H, scale)
# Stage 1: pad frame count to next 8n+5 (matches naxci1 process_chunk)
N_padded = self._next_8n5(N)
# Stage 2: add 4 → gives 8(n+1)+1, always a valid 8k+1
target = N_padded + 4
# Centered spatial padding offsets
pad_top = (th - sh) // 2
pad_bottom = th - sh - pad_top
@@ -769,8 +780,9 @@ class FlashVSRModel:
pad_right = tw - sw - pad_left
processed = []
for i in range(N):
frame = frames[i].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
for i in range(target):
frame_idx = min(i, N - 1) # clamp to last real frame
frame = frames[frame_idx].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0:
# Centered reflect padding (matches naxci1 reference)
@@ -783,17 +795,6 @@ class FlashVSRModel:
processed.append(normalized.squeeze(0).cpu().to(self.dtype))
video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
# Temporal padding: N+4 then floor to largest 8k+1 (matches naxci1 reference)
num_with_pad = N + 4
target = ((num_with_pad - 1) // 8) * 8 + 1 # largest_8n1_leq
if target < 1:
target = 1
if target > N:
pad = video[:, :, -1:].repeat(1, 1, target - N, 1, 1)
video = torch.cat([video, pad], dim=2)
elif target < N:
video = video[:, :, :target, :, :]
nf = video.shape[2]
return video, th, tw, nf, sh, sw, pad_top, pad_left
@@ -812,7 +813,7 @@ class FlashVSRModel:
@torch.no_grad()
def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104),
topk_ratio=2.0, kv_ratio=2.0, local_range=9,
topk_ratio=2.0, kv_ratio=3.0, local_range=11,
color_fix=True, unload_dit=False, seed=1,
progress_bar_cmd=None):
"""Upscale video frames with FlashVSR.