From 5071c4de4f7f99229bd3368af060df8ea3c2a1cf Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Fri, 13 Feb 2026 16:08:01 +0100
Subject: [PATCH] Fix sageattn fallback: tensors already rearranged when
 exception fires

When sageattn fails, q/k/v are already in [b, n, s, d] format from the
rearrange before the call. Use SDPA directly on them instead of calling
_sdpa_fallback, which expects [b, s, (n*d)] and crashes with a shape error.

Co-Authored-By: Claude Opus 4.6
---
 flashvsr_arch/models/wan_video_dit.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/flashvsr_arch/models/wan_video_dit.py b/flashvsr_arch/models/wan_video_dit.py
index 3b71b47..9aa97fd 100644
--- a/flashvsr_arch/models/wan_video_dit.py
+++ b/flashvsr_arch/models/wan_video_dit.py
@@ -246,7 +246,9 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
         except Exception:
             SPARSE_SAGE_AVAILABLE = False
             print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
-            x = _sdpa_fallback(q, k, v, num_heads)
+            # q, k, v already rearranged to [b, n, s, d] above
+            x = F.scaled_dot_product_attention(q, k, v)
+            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
     elif compatibility_mode:
         x = _sdpa_fallback(q, k, v, num_heads)
     elif FLASH_ATTN_3_AVAILABLE:
@@ -273,7 +275,9 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
         except Exception:
             SAGE_ATTN_AVAILABLE = False
             print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
-            x = _sdpa_fallback(q, k, v, num_heads)
+            # q, k, v already rearranged to [b, n, s, d] above
+            x = F.scaled_dot_product_attention(q, k, v)
+            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
     else:
         x = _sdpa_fallback(q, k, v, num_heads)
     return x
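
Note (illustration only, not part of the patch): a minimal standalone sketch of the
shape mismatch the commit message describes. The tensor sizes are made up, and the
failing rearrange stands in for what _sdpa_fallback is described to do on its
flat [b, s, (n*d)] input; it is not the actual helper from wan_video_dit.py.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    b, s, n, d = 1, 16, 8, 64                      # assumed batch / seq / heads / head_dim
    q = k = v = torch.randn(b, n, s, d)            # already head-split, as in the except branch

    # Patched fallback path: SDPA on the head-split tensors, then merge heads back.
    x = F.scaled_dot_product_attention(q, k, v)    # [b, n, s, d]
    x = rearrange(x, "b n s d -> b s (n d)", n=n)  # [b, s, n*d]
    print(x.shape)                                 # torch.Size([1, 16, 512])

    # A fallback expecting flat [b, s, (n*d)] input would try to split heads a
    # second time and fail on the already 4-D tensor:
    try:
        rearrange(q, "b s (n d) -> b n s d", n=n)
    except Exception as e:
        print(f"shape error, as expected: {e}")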