diff --git a/flashvsr_arch/models/wan_video_dit.py b/flashvsr_arch/models/wan_video_dit.py index 3b71b47..9aa97fd 100644 --- a/flashvsr_arch/models/wan_video_dit.py +++ b/flashvsr_arch/models/wan_video_dit.py @@ -246,7 +246,9 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads except Exception: SPARSE_SAGE_AVAILABLE = False print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA") - x = _sdpa_fallback(q, k, v, num_heads) + # q,k,v already rearranged to [b, n, s, d] above + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) elif compatibility_mode: x = _sdpa_fallback(q, k, v, num_heads) elif FLASH_ATTN_3_AVAILABLE: @@ -273,7 +275,9 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads except Exception: SAGE_ATTN_AVAILABLE = False print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA") - x = _sdpa_fallback(q, k, v, num_heads) + # q,k,v already rearranged to [b, n, s, d] above + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) else: x = _sdpa_fallback(q, k, v, num_heads) return x