Fix FlashVSR quality: two-stage temporal padding, kv_ratio=3, float64 precision
Root cause of remaining ghosting: our single-stage temporal padding (N+4, then floor to 8k+1) TRUNCATED frames whenever N+4 wasn't already of the form 8k+1. For 50 frames: 50+4=54, floored to 49, losing the last input frame; the pipeline then processed a misaligned LQ-to-output frame mapping.

Fix matches the naxci1/ComfyUI-FlashVSR_Stable two-stage approach:
1. Pad to next_8n5(N) — the next integer >= N of the form 8k+5 (minimum 21).
2. Add 4 — the result is always 8(k+1)+1, a valid 8k+1 — so frames are NEVER truncated.

Also:
- kv_ratio default 2.0 -> 3.0 (matches naxci1, max-quality KV cache)
- local_range default 9 -> 11 (more stable temporal consistency)
- sinusoidal_embedding_1d, precompute_freqs_cis, rope_apply: float32 -> float64 (matches naxci1 reference precision for embeddings and RoPE)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -289,9 +289,9 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
|||||||
|
|
||||||
def sinusoidal_embedding_1d(dim, position):
    """Return 1-D sinusoidal position embeddings of width 2*max(dim//2, 1).

    Computed in float64 (per the commit: matches the naxci1 reference
    precision for embeddings), then cast back to ``position``'s dtype.

    Args:
        dim: Target embedding dimension; half is used for cos, half for sin.
        position: 1-D tensor of positions (any real dtype).

    Returns:
        Tensor of shape ``[len(position), 2 * max(dim // 2, 1)]`` with
        ``[cos | sin]`` halves concatenated on dim 1, in ``position.dtype``.
    """
    # Guard against dim < 2 so torch.arange never gets a zero length.
    half_dim = max(dim // 2, 1)
    # Inverse frequencies 10000^(-i/half_dim), i = 0..half_dim-1, in float64.
    scale = torch.arange(half_dim, dtype=torch.float64, device=position.device)
    inv_freq = torch.pow(10000.0, -scale / half_dim)
    # Outer product: one row of angles per position.
    sinusoid = torch.outer(position.to(torch.float64), inv_freq)
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    # Cast back so callers see the dtype they passed in.
    return x.to(position.dtype)
|
|
||||||
@@ -305,9 +305,9 @@ def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
|
|||||||
|
|
||||||
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Precompute complex RoPE rotation factors e^(i * step * freq).

    Computed in float64 (per the commit: matches the naxci1 reference
    precision for RoPE), yielding a complex128 table.

    Args:
        dim: Head dimension; ``dim // 2`` (min 1) frequencies are produced.
        theta: RoPE base (default 10000.0).
        end: Number of positions (table rows) to precompute.

    Returns:
        Complex tensor of shape ``[end, max(dim // 2, 1)]`` with
        unit-magnitude entries ``polar(1, step * freq)``.
    """
    # Guard against dim < 2 so the frequency table is never empty.
    half_dim = max(dim // 2, 1)
    # Even indices 0, 2, 4, ... drive the per-pair frequencies.
    base = torch.arange(0, dim, 2, dtype=torch.float64)[:half_dim]
    # max(dim, 1) avoids division by zero for degenerate dim == 0.
    freqs = torch.pow(theta, -base / max(dim, 1))
    steps = torch.arange(end, dtype=torch.float64)
    # angles[s, f] = s * freqs[f]; one rotation angle per (position, freq).
    angles = torch.outer(steps, freqs)
    # Unit-modulus complex numbers: cos(angle) + i*sin(angle).
    return torch.polar(torch.ones_like(angles), angles)
|
|
||||||
@@ -315,8 +315,7 @@ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
|||||||
def rope_apply(x, freqs, num_heads):
|
def rope_apply(x, freqs, num_heads):
|
||||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||||
orig_dtype = x.dtype
|
orig_dtype = x.dtype
|
||||||
work_dtype = torch.float32 if orig_dtype in (torch.float16, torch.bfloat16) else orig_dtype
|
reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
|
||||||
reshaped = x.to(work_dtype).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
|
|
||||||
x_complex = torch.view_as_complex(reshaped)
|
x_complex = torch.view_as_complex(reshaped)
|
||||||
freqs = freqs.to(dtype=x_complex.dtype, device=x_complex.device)
|
freqs = freqs.to(dtype=x_complex.dtype, device=x_complex.device)
|
||||||
x_out = torch.view_as_real(x_complex * freqs).flatten(2)
|
x_out = torch.view_as_real(x_complex * freqs).flatten(2)
|
||||||
|
|||||||
39
inference.py
39
inference.py
@@ -740,17 +740,22 @@ class FlashVSRModel:
|
|||||||
result = torch.cat([result, pad], dim=0)
|
result = torch.cat([result, pad], dim=0)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _next_8n5(n, minimum=21):
|
||||||
|
"""Next integer >= n of the form 8k+5 (minimum 21)."""
|
||||||
|
if n < minimum:
|
||||||
|
return minimum
|
||||||
|
return ((n - 5 + 7) // 8) * 8 + 5
|
||||||
|
|
||||||
def _prepare_video(self, frames, scale):
|
def _prepare_video(self, frames, scale):
|
||||||
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F_padded, H, W] [-1,1].
|
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F_padded, H, W] [-1,1].
|
||||||
|
|
||||||
Matches naxci1/ComfyUI-FlashVSR_Stable preprocessing:
|
Matches naxci1/ComfyUI-FlashVSR_Stable two-stage temporal padding:
|
||||||
1. Bicubic-upscale each frame to target resolution
|
1. Bicubic-upscale each frame to target resolution
|
||||||
2. Centered symmetric padding to 128-pixel alignment (reflect mode)
|
2. Centered symmetric padding to 128-pixel alignment (reflect mode)
|
||||||
3. Normalize to [-1, 1]
|
3. Normalize to [-1, 1]
|
||||||
4. Temporal padding: N+4 then floor to largest 8k+1 (matches naxci1 reference)
|
4. Stage 1: Pad frame count to next 8n+5 (min 21) by repeating last frame
|
||||||
|
5. Stage 2: Add 4 → result is always 8k+1 (since 8n+5+4 = 8(n+1)+1)
|
||||||
No front dummy frames — the pipeline handles LQ indexing correctly
|
|
||||||
starting from frame 0.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
video: [1, C, F_padded, H, W] tensor
|
video: [1, C, F_padded, H, W] tensor
|
||||||
@@ -762,6 +767,12 @@ class FlashVSRModel:
|
|||||||
N, H, W, C = frames.shape
|
N, H, W, C = frames.shape
|
||||||
sw, sh, tw, th = self._compute_dims(W, H, scale)
|
sw, sh, tw, th = self._compute_dims(W, H, scale)
|
||||||
|
|
||||||
|
# Stage 1: pad frame count to next 8n+5 (matches naxci1 process_chunk)
|
||||||
|
N_padded = self._next_8n5(N)
|
||||||
|
|
||||||
|
# Stage 2: add 4 → gives 8(n+1)+1, always a valid 8k+1
|
||||||
|
target = N_padded + 4
|
||||||
|
|
||||||
# Centered spatial padding offsets
|
# Centered spatial padding offsets
|
||||||
pad_top = (th - sh) // 2
|
pad_top = (th - sh) // 2
|
||||||
pad_bottom = th - sh - pad_top
|
pad_bottom = th - sh - pad_top
|
||||||
@@ -769,8 +780,9 @@ class FlashVSRModel:
|
|||||||
pad_right = tw - sw - pad_left
|
pad_right = tw - sw - pad_left
|
||||||
|
|
||||||
processed = []
|
processed = []
|
||||||
for i in range(N):
|
for i in range(target):
|
||||||
frame = frames[i].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
|
frame_idx = min(i, N - 1) # clamp to last real frame
|
||||||
|
frame = frames[frame_idx].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
|
||||||
upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
|
upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
|
||||||
if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0:
|
if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0:
|
||||||
# Centered reflect padding (matches naxci1 reference)
|
# Centered reflect padding (matches naxci1 reference)
|
||||||
@@ -783,17 +795,6 @@ class FlashVSRModel:
|
|||||||
processed.append(normalized.squeeze(0).cpu().to(self.dtype))
|
processed.append(normalized.squeeze(0).cpu().to(self.dtype))
|
||||||
|
|
||||||
video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
|
video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
|
||||||
|
|
||||||
# Temporal padding: N+4 then floor to largest 8k+1 (matches naxci1 reference)
|
|
||||||
num_with_pad = N + 4
|
|
||||||
target = ((num_with_pad - 1) // 8) * 8 + 1 # largest_8n1_leq
|
|
||||||
if target < 1:
|
|
||||||
target = 1
|
|
||||||
if target > N:
|
|
||||||
pad = video[:, :, -1:].repeat(1, 1, target - N, 1, 1)
|
|
||||||
video = torch.cat([video, pad], dim=2)
|
|
||||||
elif target < N:
|
|
||||||
video = video[:, :, :target, :, :]
|
|
||||||
nf = video.shape[2]
|
nf = video.shape[2]
|
||||||
|
|
||||||
return video, th, tw, nf, sh, sw, pad_top, pad_left
|
return video, th, tw, nf, sh, sw, pad_top, pad_left
|
||||||
@@ -812,7 +813,7 @@ class FlashVSRModel:
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104),
|
def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104),
|
||||||
topk_ratio=2.0, kv_ratio=2.0, local_range=9,
|
topk_ratio=2.0, kv_ratio=3.0, local_range=11,
|
||||||
color_fix=True, unload_dit=False, seed=1,
|
color_fix=True, unload_dit=False, seed=1,
|
||||||
progress_bar_cmd=None):
|
progress_bar_cmd=None):
|
||||||
"""Upscale video frames with FlashVSR.
|
"""Upscale video frames with FlashVSR.
|
||||||
|
|||||||
20
nodes.py
20
nodes.py
@@ -1672,12 +1672,12 @@ class FlashVSRUpscale:
|
|||||||
"tooltip": "Sparse attention ratio. Higher = faster but may lose fine detail.",
|
"tooltip": "Sparse attention ratio. Higher = faster but may lose fine detail.",
|
||||||
}),
|
}),
|
||||||
"kv_ratio": ("FLOAT", {
|
"kv_ratio": ("FLOAT", {
|
||||||
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
|
"default": 3.0, "min": 1.0, "max": 4.0, "step": 0.1,
|
||||||
"tooltip": "KV cache ratio. Higher = better quality, more VRAM.",
|
"tooltip": "KV cache ratio. Higher = better quality, more VRAM. 3.0 recommended.",
|
||||||
}),
|
}),
|
||||||
"local_range": ([9, 11], {
|
"local_range": ([9, 11], {
|
||||||
"default": 9,
|
"default": 11,
|
||||||
"tooltip": "Local attention window. 9=sharper details, 11=more temporal stability.",
|
"tooltip": "Local attention window. 9=sharper details, 11=more temporal stability (recommended).",
|
||||||
}),
|
}),
|
||||||
"color_fix": ("BOOLEAN", {
|
"color_fix": ("BOOLEAN", {
|
||||||
"default": True,
|
"default": True,
|
||||||
@@ -1731,12 +1731,14 @@ class FlashVSRUpscale:
|
|||||||
chunks.append((prev_start, last_end))
|
chunks.append((prev_start, last_end))
|
||||||
|
|
||||||
# Estimate total pipeline steps for progress bar
|
# Estimate total pipeline steps for progress bar
|
||||||
# Mirrors _prepare_video: largest_8n1_leq(N + 4)
|
# Mirrors _prepare_video two-stage padding: next_8n5(N) + 4
|
||||||
|
def _next_8n5(n, minimum=21):
|
||||||
|
return minimum if n < minimum else ((n - 5 + 7) // 8) * 8 + 5
|
||||||
|
|
||||||
total_steps = 0
|
total_steps = 0
|
||||||
for cs, ce in chunks:
|
for cs, ce in chunks:
|
||||||
n = ce - cs
|
n = ce - cs
|
||||||
num_with_pad = n + 4
|
target = _next_8n5(n) + 4 # always 8k+1
|
||||||
target = ((num_with_pad - 1) // 8) * 8 + 1
|
|
||||||
total_steps += max(1, (target - 1) // 8 - 2)
|
total_steps += max(1, (target - 1) // 8 - 2)
|
||||||
|
|
||||||
pbar = ProgressBar(total_steps)
|
pbar = ProgressBar(total_steps)
|
||||||
@@ -1826,10 +1828,10 @@ class FlashVSRSegmentUpscale:
|
|||||||
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
|
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
|
||||||
}),
|
}),
|
||||||
"kv_ratio": ("FLOAT", {
|
"kv_ratio": ("FLOAT", {
|
||||||
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
|
"default": 3.0, "min": 1.0, "max": 4.0, "step": 0.1,
|
||||||
}),
|
}),
|
||||||
"local_range": ([9, 11], {
|
"local_range": ([9, 11], {
|
||||||
"default": 9,
|
"default": 11,
|
||||||
}),
|
}),
|
||||||
"color_fix": ("BOOLEAN", {
|
"color_fix": ("BOOLEAN", {
|
||||||
"default": True,
|
"default": True,
|
||||||
|
|||||||
Reference in New Issue
Block a user