Fix FlashVSR attention mask and output quality

- Use generate_draft_block_mask_refined for sparse attention mask (matches
  naxci1's generate_draft_block_mask_sage with proper half-block key scoring)
- Remove spurious repeat_interleave(2, dim=-1) from generate_draft_block_mask
  that doubled the key dimension incorrectly
- Add torch.clamp(0, 1) to _to_frames output (matches naxci1's tensor2video)
- Add .to(self.device) on LQ video slices in streaming loop for all pipelines

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 18:41:43 +01:00
parent 76dff7e573
commit 3b87652184
4 changed files with 6 additions and 7 deletions

View File

@@ -409,7 +409,7 @@ class FlashVSRFullPipeline(BasePipeline):
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :]
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
@@ -425,7 +425,7 @@ class FlashVSRFullPipeline(BasePipeline):
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :]
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue

View File

@@ -400,7 +400,7 @@ class FlashVSRTinyPipeline(BasePipeline):
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :]
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
@@ -416,7 +416,7 @@ class FlashVSRTinyPipeline(BasePipeline):
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :]
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue