Fix sageattn fallback: tensors already rearranged when exception fires

When sageattn fails, q/k/v are already in [b,n,s,d] format from the
rearrange before the call. Use SDPA directly on them instead of calling
_sdpa_fallback which expects [b,s,(n*d)] and crashes with a shape error.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:08:01 +01:00
parent dd69a2fd2b
commit 5071c4de4f


@@ -246,7 +246,9 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
         except Exception:
             SPARSE_SAGE_AVAILABLE = False
             print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
-            x = _sdpa_fallback(q, k, v, num_heads)
+            # q,k,v already rearranged to [b, n, s, d] above
+            x = F.scaled_dot_product_attention(q, k, v)
+            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
     elif compatibility_mode:
         x = _sdpa_fallback(q, k, v, num_heads)
     elif FLASH_ATTN_3_AVAILABLE:
@@ -273,7 +275,9 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
         except Exception:
             SAGE_ATTN_AVAILABLE = False
             print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
-            x = _sdpa_fallback(q, k, v, num_heads)
+            # q,k,v already rearranged to [b, n, s, d] above
+            x = F.scaled_dot_product_attention(q, k, v)
+            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
     else:
         x = _sdpa_fallback(q, k, v, num_heads)
     return x