Fix FlashVSR quality: two-stage temporal padding, kv_ratio=3, float64 precision

Root cause of remaining ghosting: our single-stage temporal padding
(N+4 → floor to 8k+1) TRUNCATED frames when N+4 wasn't already 8k+1.
For 50 frames: 50+4=54 → floor to 49, LOSING the last input frame.
The pipeline then processed the clip with a misaligned LQ→output frame mapping.

Fix matches naxci1/ComfyUI-FlashVSR_Stable two-stage approach:
1. Pad to next_8n5(N) (next integer >= N of form 8k+5, minimum 21)
2. Add 4 → result is always 8(k+1)+1, a valid 8k+1 — NEVER truncates
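A minimal sketch of the two-stage rule above (helper names like `next_8n5` are assumptions about the naxci1 code, not verbatim from it):

```python
def next_8n5(n: int) -> int:
    """Smallest integer >= n of the form 8k + 5, clamped to a minimum of 21."""
    k = max(-(-(n - 5) // 8), 2)  # ceil((n - 5) / 8); k >= 2 gives the floor 8*2+5 = 21
    return 8 * k + 5

def padded_frames(n: int) -> int:
    """Stage 1: pad up to 8k+5. Stage 2: add 4 -> 8k+9 = 8(k+1)+1, always a valid 8k+1."""
    return next_8n5(n) + 4
```

For the 50-frame case from the message: next_8n5(50) = 53, plus 4 gives 57 = 8*7+1, so nothing is floored away.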

Also:
- kv_ratio default 2.0→3.0 (matches naxci1, max quality KV cache)
- local_range default 9→11 (more stable temporal consistency)
- sinusoidal_embedding_1d, precompute_freqs_cis, rope_apply: float32→float64
  (matches naxci1 reference precision for embeddings and RoPE)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 18:06:46 +01:00
parent fa250897a2
commit 76dff7e573
3 changed files with 36 additions and 34 deletions

@@ -289,9 +289,9 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
 def sinusoidal_embedding_1d(dim, position):
     half_dim = max(dim // 2, 1)
-    scale = torch.arange(half_dim, dtype=torch.float32, device=position.device)
+    scale = torch.arange(half_dim, dtype=torch.float64, device=position.device)
     inv_freq = torch.pow(10000.0, -scale / half_dim)
-    sinusoid = torch.outer(position.to(torch.float32), inv_freq)
+    sinusoid = torch.outer(position.to(torch.float64), inv_freq)
     x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
     return x.to(position.dtype)
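Assembled, the patched function behaves like this numpy stand-in (an illustrative sketch, not the repo's code; the torch version additionally casts the result back to `position.dtype`):

```python
import numpy as np

def sinusoidal_embedding_1d_np(dim, position):
    # Mirror of the patched torch function: float64 frequencies and phases,
    # then [cos | sin] concatenated along the feature axis.
    half_dim = max(dim // 2, 1)
    scale = np.arange(half_dim, dtype=np.float64)
    inv_freq = 10000.0 ** (-scale / half_dim)
    sinusoid = np.outer(position.astype(np.float64), inv_freq)
    return np.concatenate([np.cos(sinusoid), np.sin(sinusoid)], axis=1)
```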
@@ -305,9 +305,9 @@ def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
 def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
     half_dim = max(dim // 2, 1)
-    base = torch.arange(0, dim, 2, dtype=torch.float32)[:half_dim]
+    base = torch.arange(0, dim, 2, dtype=torch.float64)[:half_dim]
     freqs = torch.pow(theta, -base / max(dim, 1))
-    steps = torch.arange(end, dtype=torch.float32)
+    steps = torch.arange(end, dtype=torch.float64)
     angles = torch.outer(steps, freqs)
     return torch.polar(torch.ones_like(angles), angles)
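The float64 change matters here because rounding error in the phase grid `steps ⊗ freqs` grows with the angle itself, i.e. with the position index. A quick numpy check (illustrative only, using the same `theta`/shape conventions as the hunk above) shows the float32 phases drifting for late positions:

```python
import numpy as np

dim, end, theta = 128, 1024, 10000.0
base = np.arange(0, dim, 2)[: dim // 2]
freqs64 = theta ** (-base / dim)                      # float64 frequencies
angles64 = np.outer(np.arange(end, dtype=np.float64), freqs64)
angles32 = np.outer(np.arange(end, dtype=np.float32),
                    freqs64.astype(np.float32))       # float32 pipeline throughout
drift = np.abs(angles32.astype(np.float64) - angles64)
# drift[step] scales roughly linearly with step: late frames see the largest
# phase error, which is exactly where temporal inconsistency shows up.
```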
@@ -315,8 +315,7 @@ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
 def rope_apply(x, freqs, num_heads):
     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
     orig_dtype = x.dtype
-    work_dtype = torch.float32 if orig_dtype in (torch.float16, torch.bfloat16) else orig_dtype
-    reshaped = x.to(work_dtype).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
+    reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
     x_complex = torch.view_as_complex(reshaped)
     freqs = freqs.to(dtype=x_complex.dtype, device=x_complex.device)
     x_out = torch.view_as_real(x_complex * freqs).flatten(2)
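The rotate-pairs core of rope_apply can be sketched in numpy (the names and the explicit freqs broadcast are mine, not the repo's):

```python
import numpy as np

def rope_apply_np(x, freqs, num_heads):
    """x: (b, s, n*d) real. freqs: (s, d//2) unit complex phasors.
    Rotates consecutive (even, odd) feature pairs, as torch.view_as_complex does."""
    b, s, _ = x.shape
    xh = x.astype(np.float64).reshape(b, s, num_heads, -1, 2)
    xc = xh[..., 0] + 1j * xh[..., 1]       # view_as_complex
    out = xc * freqs[None, :, None, :]      # broadcast the per-position rotation
    real = np.stack([out.real, out.imag], axis=-1)
    return real.reshape(b, s, -1)
```

Because each pair is multiplied by a unit phasor, the rotation preserves magnitudes; checking that the output norm equals the input norm is a cheap sanity test for any reimplementation.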