Fix FlashVSR quality: two-stage temporal padding, kv_ratio=3, float64 precision
Root cause of remaining ghosting: our single-stage temporal padding (N+4 → floor to 8k+1) TRUNCATED frames when N+4 wasn't already 8k+1. For 50 frames: 50+4=54 → floor to 49, LOSING the last input frame. The pipeline then processed misaligned LQ→output frame mapping. Fix matches naxci1/ComfyUI-FlashVSR_Stable two-stage approach: 1. Pad to next_8n5(N) (next integer >= N of form 8k+5, minimum 21) 2. Add 4 → result is always 8(k+1)+1, a valid 8k+1 — NEVER truncates Also: - kv_ratio default 2.0→3.0 (matches naxci1, max quality KV cache) - local_range default 9→11 (more stable temporal consistency) - sinusoidal_embedding_1d, precompute_freqs_cis, rope_apply: float32→float64 (matches naxci1 reference precision for embeddings and RoPE) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -289,9 +289,9 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
half_dim = max(dim // 2, 1)
|
||||
scale = torch.arange(half_dim, dtype=torch.float32, device=position.device)
|
||||
scale = torch.arange(half_dim, dtype=torch.float64, device=position.device)
|
||||
inv_freq = torch.pow(10000.0, -scale / half_dim)
|
||||
sinusoid = torch.outer(position.to(torch.float32), inv_freq)
|
||||
sinusoid = torch.outer(position.to(torch.float64), inv_freq)
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
return x.to(position.dtype)
|
||||
|
||||
@@ -305,9 +305,9 @@ def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
||||
half_dim = max(dim // 2, 1)
|
||||
base = torch.arange(0, dim, 2, dtype=torch.float32)[:half_dim]
|
||||
base = torch.arange(0, dim, 2, dtype=torch.float64)[:half_dim]
|
||||
freqs = torch.pow(theta, -base / max(dim, 1))
|
||||
steps = torch.arange(end, dtype=torch.float32)
|
||||
steps = torch.arange(end, dtype=torch.float64)
|
||||
angles = torch.outer(steps, freqs)
|
||||
return torch.polar(torch.ones_like(angles), angles)
|
||||
|
||||
@@ -315,8 +315,7 @@ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
||||
def rope_apply(x, freqs, num_heads):
|
||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||
orig_dtype = x.dtype
|
||||
work_dtype = torch.float32 if orig_dtype in (torch.float16, torch.bfloat16) else orig_dtype
|
||||
reshaped = x.to(work_dtype).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
|
||||
reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
|
||||
x_complex = torch.view_as_complex(reshaped)
|
||||
freqs = freqs.to(dtype=x_complex.dtype, device=x_complex.device)
|
||||
x_out = torch.view_as_real(x_complex * freqs).flatten(2)
|
||||
|
||||
Reference in New Issue
Block a user