Fix FlashVSR quality: two-stage temporal padding, kv_ratio=3, float64 precision

Root cause of remaining ghosting: our single-stage temporal padding
(N+4 → floor to 8k+1) TRUNCATED frames when N+4 wasn't already 8k+1.
For 50 frames: 50+4=54 → floor to 49, LOSING the last input frame.
The pipeline then processed misaligned LQ→output frame mapping.

Fix matches naxci1/ComfyUI-FlashVSR_Stable two-stage approach:
1. Pad to next_8n5(N) (next integer >= N of form 8k+5, minimum 21)
2. Add 4 → result is always 8(k+1)+1, a valid 8k+1 — NEVER truncates

Also:
- kv_ratio default 2.0→3.0 (matches naxci1, max quality KV cache)
- local_range default 9→11 (more stable temporal consistency)
- sinusoidal_embedding_1d, precompute_freqs_cis, rope_apply: float32→float64
  (matches naxci1 reference precision for embeddings and RoPE)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 18:06:46 +01:00
parent fa250897a2
commit 76dff7e573
3 changed files with 36 additions and 34 deletions

View File

@@ -1672,12 +1672,12 @@ class FlashVSRUpscale:
"tooltip": "Sparse attention ratio. Higher = faster but may lose fine detail.",
}),
"kv_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
"tooltip": "KV cache ratio. Higher = better quality, more VRAM.",
"default": 3.0, "min": 1.0, "max": 4.0, "step": 0.1,
"tooltip": "KV cache ratio. Higher = better quality, more VRAM. 3.0 recommended.",
}),
"local_range": ([9, 11], {
"default": 9,
"tooltip": "Local attention window. 9=sharper details, 11=more temporal stability.",
"default": 11,
"tooltip": "Local attention window. 9=sharper details, 11=more temporal stability (recommended).",
}),
"color_fix": ("BOOLEAN", {
"default": True,
@@ -1731,12 +1731,14 @@ class FlashVSRUpscale:
chunks.append((prev_start, last_end))
# Estimate total pipeline steps for progress bar
# Mirrors _prepare_video: largest_8n1_leq(N + 4)
# Mirrors _prepare_video two-stage padding: next_8n5(N) + 4
def _next_8n5(n, minimum=21):
return minimum if n < minimum else ((n - 5 + 7) // 8) * 8 + 5
total_steps = 0
for cs, ce in chunks:
n = ce - cs
num_with_pad = n + 4
target = ((num_with_pad - 1) // 8) * 8 + 1
target = _next_8n5(n) + 4 # always 8k+1
total_steps += max(1, (target - 1) // 8 - 2)
pbar = ProgressBar(total_steps)
@@ -1826,10 +1828,10 @@ class FlashVSRSegmentUpscale:
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
}),
"kv_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
"default": 3.0, "min": 1.0, "max": 4.0, "step": 0.1,
}),
"local_range": ([9, 11], {
"default": 9,
"default": 11,
}),
"color_fix": ("BOOLEAN", {
"default": True,