Fix sageattn crash on Blackwell GPUs (sm_120)
SageAttention CUDA kernels don't support Blackwell yet. Catch runtime failures from sageattn/sparse_sageattn, disable them, and fall back to PyTorch SDPA. Only pays the try/except cost once per session. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -218,27 +218,37 @@ def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
|
||||
# ----------------------------
|
||||
# Attention kernels
|
||||
# ----------------------------
|
||||
def _sdpa_fallback(q, k, v, num_heads):
|
||||
"""PyTorch scaled dot-product attention (always available)."""
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
|
||||
|
||||
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
|
||||
global SPARSE_SAGE_AVAILABLE, SAGE_ATTN_AVAILABLE, FLASH_ATTN_2_AVAILABLE, FLASH_ATTN_3_AVAILABLE
|
||||
|
||||
if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
|
||||
seqlen = q.shape[1]
|
||||
seqlen_kv = k.shape[1]
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
base_blockmask = attention_mask
|
||||
x = sparse_sageattn(
|
||||
q, k, v,
|
||||
mask_id=base_blockmask.to(torch.int8),
|
||||
is_causal=False,
|
||||
tensor_layout="HND"
|
||||
)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
try:
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
base_blockmask = attention_mask
|
||||
x = sparse_sageattn(
|
||||
q, k, v,
|
||||
mask_id=base_blockmask.to(torch.int8),
|
||||
is_causal=False,
|
||||
tensor_layout="HND"
|
||||
)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
except Exception:
|
||||
SPARSE_SAGE_AVAILABLE = False
|
||||
print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
|
||||
x = _sdpa_fallback(q, k, v, num_heads)
|
||||
elif compatibility_mode:
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
x = _sdpa_fallback(q, k, v, num_heads)
|
||||
elif FLASH_ATTN_3_AVAILABLE:
|
||||
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
||||
@@ -254,17 +264,18 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
|
||||
x = flash_attn.flash_attn_func(q, k, v)
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
elif SAGE_ATTN_AVAILABLE:
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
x = sageattn(q, k, v)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
try:
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
x = sageattn(q, k, v)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
except Exception:
|
||||
SAGE_ATTN_AVAILABLE = False
|
||||
print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
|
||||
x = _sdpa_fallback(q, k, v, num_heads)
|
||||
else:
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
x = _sdpa_fallback(q, k, v, num_heads)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user