Fix FlashVSR attention mask and output quality
- Use generate_draft_block_mask_refined for sparse attention mask (matches naxci1's generate_draft_block_mask_sage with proper half-block key scoring) - Remove spurious repeat_interleave(2, dim=-1) from generate_draft_block_mask that doubled the key dimension incorrectly - Add torch.clamp(0, 1) to _to_frames output (matches naxci1's tensor2video) - Add .to(self.device) on LQ video slices in streaming loop for all pipelines Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -159,7 +159,6 @@ def generate_draft_block_mask(batch_size, nheads, seqlen,
|
|||||||
# 修正:上行变量名统一
|
# 修正:上行变量名统一
|
||||||
# mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new
|
# mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new
|
||||||
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
||||||
mask = mask.repeat_interleave(2, dim=-1)
|
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
@@ -413,7 +412,7 @@ class SelfAttention(nn.Module):
|
|||||||
self.local_attn_mask_h = h//8
|
self.local_attn_mask_h = h//8
|
||||||
self.local_attn_mask_w = w//8
|
self.local_attn_mask_w = w//8
|
||||||
self.local_range = local_range
|
self.local_range = local_range
|
||||||
attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
|
attention_mask = generate_draft_block_mask_refined(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
|
||||||
|
|
||||||
x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
|
x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
|
||||||
|
|
||||||
|
|||||||
@@ -409,7 +409,7 @@ class FlashVSRFullPipeline(BasePipeline):
|
|||||||
inner_loop_num = 7
|
inner_loop_num = 7
|
||||||
for inner_idx in range(inner_loop_num):
|
for inner_idx in range(inner_loop_num):
|
||||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
||||||
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :]
|
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
|
||||||
) if LQ_video is not None else None
|
) if LQ_video is not None else None
|
||||||
if cur is None:
|
if cur is None:
|
||||||
continue
|
continue
|
||||||
@@ -425,7 +425,7 @@ class FlashVSRFullPipeline(BasePipeline):
|
|||||||
inner_loop_num = 2
|
inner_loop_num = 2
|
||||||
for inner_idx in range(inner_loop_num):
|
for inner_idx in range(inner_loop_num):
|
||||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
||||||
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :]
|
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
|
||||||
) if LQ_video is not None else None
|
) if LQ_video is not None else None
|
||||||
if cur is None:
|
if cur is None:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -400,7 +400,7 @@ class FlashVSRTinyPipeline(BasePipeline):
|
|||||||
inner_loop_num = 7
|
inner_loop_num = 7
|
||||||
for inner_idx in range(inner_loop_num):
|
for inner_idx in range(inner_loop_num):
|
||||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
||||||
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :]
|
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
|
||||||
) if LQ_video is not None else None
|
) if LQ_video is not None else None
|
||||||
if cur is None:
|
if cur is None:
|
||||||
continue
|
continue
|
||||||
@@ -416,7 +416,7 @@ class FlashVSRTinyPipeline(BasePipeline):
|
|||||||
inner_loop_num = 2
|
inner_loop_num = 2
|
||||||
for inner_idx in range(inner_loop_num):
|
for inner_idx in range(inner_loop_num):
|
||||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
||||||
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :]
|
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
|
||||||
) if LQ_video is not None else None
|
) if LQ_video is not None else None
|
||||||
if cur is None:
|
if cur is None:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -805,7 +805,7 @@ class FlashVSRModel:
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
v = video.squeeze(0) if video.dim() == 5 else video
|
v = video.squeeze(0) if video.dim() == 5 else video
|
||||||
v = rearrange(v, "C F H W -> F H W C")
|
v = rearrange(v, "C F H W -> F H W C")
|
||||||
return (v.float() + 1.0) / 2.0
|
return torch.clamp((v.float() + 1.0) / 2.0, 0.0, 1.0)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Main upscale method
|
# Main upscale method
|
||||||
|
|||||||
Reference in New Issue
Block a user