Fix sageattn crash on Blackwell GPUs (sm_120)
SageAttention CUDA kernels don't support Blackwell yet. Catch runtime failures from sageattn/sparse_sageattn, disable them, and fall back to PyTorch SDPA. Only pays the try/except cost once per session. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -218,10 +218,20 @@ def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
|
|||||||
# ----------------------------
|
# ----------------------------
|
||||||
# Attention kernels
|
# Attention kernels
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
|
def _sdpa_fallback(q, k, v, num_heads):
|
||||||
|
"""PyTorch scaled dot-product attention (always available)."""
|
||||||
|
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
x = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||||
|
|
||||||
|
|
||||||
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
|
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
|
||||||
|
global SPARSE_SAGE_AVAILABLE, SAGE_ATTN_AVAILABLE, FLASH_ATTN_2_AVAILABLE, FLASH_ATTN_3_AVAILABLE
|
||||||
|
|
||||||
if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
|
if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
|
||||||
seqlen = q.shape[1]
|
try:
|
||||||
seqlen_kv = k.shape[1]
|
|
||||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||||
@@ -233,12 +243,12 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
|
|||||||
tensor_layout="HND"
|
tensor_layout="HND"
|
||||||
)
|
)
|
||||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||||
|
except Exception:
|
||||||
|
SPARSE_SAGE_AVAILABLE = False
|
||||||
|
print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
|
||||||
|
x = _sdpa_fallback(q, k, v, num_heads)
|
||||||
elif compatibility_mode:
|
elif compatibility_mode:
|
||||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
x = _sdpa_fallback(q, k, v, num_heads)
|
||||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
x = F.scaled_dot_product_attention(q, k, v)
|
|
||||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
|
||||||
elif FLASH_ATTN_3_AVAILABLE:
|
elif FLASH_ATTN_3_AVAILABLE:
|
||||||
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
||||||
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
||||||
@@ -254,17 +264,18 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
|
|||||||
x = flash_attn.flash_attn_func(q, k, v)
|
x = flash_attn.flash_attn_func(q, k, v)
|
||||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||||
elif SAGE_ATTN_AVAILABLE:
|
elif SAGE_ATTN_AVAILABLE:
|
||||||
|
try:
|
||||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||||
x = sageattn(q, k, v)
|
x = sageattn(q, k, v)
|
||||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||||
|
except Exception:
|
||||||
|
SAGE_ATTN_AVAILABLE = False
|
||||||
|
print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
|
||||||
|
x = _sdpa_fallback(q, k, v, num_heads)
|
||||||
else:
|
else:
|
||||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
x = _sdpa_fallback(q, k, v, num_heads)
|
||||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
x = F.scaled_dot_product_attention(q, k, v)
|
|
||||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user