diff --git a/flashvsr_arch/models/wan_video_dit.py b/flashvsr_arch/models/wan_video_dit.py index 8500fc1..a212625 100644 --- a/flashvsr_arch/models/wan_video_dit.py +++ b/flashvsr_arch/models/wan_video_dit.py @@ -159,7 +159,6 @@ def generate_draft_block_mask(batch_size, nheads, seqlen, # 修正:上行变量名统一 # mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1) - mask = mask.repeat_interleave(2, dim=-1) return mask @@ -413,7 +412,7 @@ class SelfAttention(nn.Module): self.local_attn_mask_h = h//8 self.local_attn_mask_w = w//8 self.local_range = local_range - attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask) + attention_mask = generate_draft_block_mask_refined(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask) x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask) diff --git a/flashvsr_arch/pipelines/flashvsr_full.py b/flashvsr_arch/pipelines/flashvsr_full.py index 4eed010..7f3064a 100644 --- a/flashvsr_arch/pipelines/flashvsr_full.py +++ b/flashvsr_arch/pipelines/flashvsr_full.py @@ -409,7 +409,7 @@ class FlashVSRFullPipeline(BasePipeline): inner_loop_num = 7 for inner_idx in range(inner_loop_num): cur = self.denoising_model().LQ_proj_in.stream_forward( - LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :] + LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device) ) if LQ_video is not None else None if cur is None: continue @@ -425,7 +425,7 @@ class FlashVSRFullPipeline(BasePipeline): inner_loop_num = 2 for inner_idx in range(inner_loop_num): cur = self.denoising_model().LQ_proj_in.stream_forward( - LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :] + LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device) ) if LQ_video is not None else None if cur is None: continue diff --git a/flashvsr_arch/pipelines/flashvsr_tiny.py b/flashvsr_arch/pipelines/flashvsr_tiny.py index 14e7363..53a492a 100644 --- a/flashvsr_arch/pipelines/flashvsr_tiny.py +++ b/flashvsr_arch/pipelines/flashvsr_tiny.py @@ -400,7 +400,7 @@ class FlashVSRTinyPipeline(BasePipeline): inner_loop_num = 7 for inner_idx in range(inner_loop_num): cur = self.denoising_model().LQ_proj_in.stream_forward( - LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :] + LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device) ) if LQ_video is not None else None if cur is None: continue @@ -416,7 +416,7 @@ class FlashVSRTinyPipeline(BasePipeline): inner_loop_num = 2 for inner_idx in range(inner_loop_num): cur = self.denoising_model().LQ_proj_in.stream_forward( - LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :] + LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device) ) if LQ_video is not None else None if cur is None: continue diff --git a/inference.py b/inference.py index 0074820..dd3c5dd 100644 --- a/inference.py +++ b/inference.py @@ -805,7 +805,7 @@ class FlashVSRModel: from einops import rearrange v = video.squeeze(0) if video.dim() == 5 else video v = rearrange(v, "C F H W -> F H W C") - return (v.float() + 1.0) / 2.0 + return torch.clamp((v.float() + 1.0) / 2.0, 0.0, 1.0) # ------------------------------------------------------------------ # Main upscale method