Fix sageattn fallback: tensors already rearranged when exception fires
When sageattn fails, q/k/v are already in [b, n, s, d] format from the rearrange before the call. Use SDPA directly on them instead of calling _sdpa_fallback, which expects [b, s, (n*d)] and crashes with a shape error.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
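For context, a minimal sketch of the shape mismatch described above. The body of _sdpa_fallback is not part of this commit, so the helper below is only an assumed stand-in that takes flat [b, s, n*d] tensors; the names and sizes are illustrative:

import torch
import torch.nn.functional as F
from einops import rearrange

def _sdpa_fallback(q, k, v, num_heads):
    # Assumed stand-in: split heads from [b, s, n*d], run SDPA, merge heads again.
    q, k, v = (rearrange(t, "b s (n d) -> b n s d", n=num_heads) for t in (q, k, v))
    x = F.scaled_dot_product_attention(q, k, v)
    return rearrange(x, "b n s d -> b s (n d)", n=num_heads)

b, s, n, d = 1, 16, 8, 64
q = k = v = torch.randn(b, n, s, d)  # already rearranged for sageattn, as in the except branch

# Feeding 4-D tensors to the fallback breaks its "b s (n d) -> b n s d" rearrange,
# so the fixed code calls SDPA directly and only rearranges the output back:
x = F.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=n)  # back to [b, s, n*d]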
@@ -246,7 +246,9 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
         except Exception:
             SPARSE_SAGE_AVAILABLE = False
             print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
-            x = _sdpa_fallback(q, k, v, num_heads)
+            # q,k,v already rearranged to [b, n, s, d] above
+            x = F.scaled_dot_product_attention(q, k, v)
+            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
     elif compatibility_mode:
         x = _sdpa_fallback(q, k, v, num_heads)
     elif FLASH_ATTN_3_AVAILABLE:
@@ -273,7 +275,9 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
         except Exception:
             SAGE_ATTN_AVAILABLE = False
             print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
-            x = _sdpa_fallback(q, k, v, num_heads)
+            # q,k,v already rearranged to [b, n, s, d] above
+            x = F.scaled_dot_product_attention(q, k, v)
+            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
     else:
         x = _sdpa_fallback(q, k, v, num_heads)
     return x