Fix sageattn crash on Blackwell GPUs (sm_120)
SageAttention's CUDA kernels don't support Blackwell yet. Catch runtime failures from sageattn/sparse_sageattn, disable them, and fall back to PyTorch SDPA. The try/except cost is paid only once per session.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
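For context, here is a minimal standalone sketch of the probe-once pattern the commit uses, assuming the sageattention package and the einops layouts used elsewhere in the file; the module-level flag and the attention wrapper below are illustrative, not the committed code (the actual change is the diff that follows):

# Minimal sketch of the probe-once fallback pattern (illustrative, not the committed code).
import torch
import torch.nn.functional as F
from einops import rearrange

SAGE_ATTN_AVAILABLE = True  # illustrative flag; the real file sets its flags at import time


def _sdpa_fallback(q, k, v, num_heads):
    """PyTorch scaled dot-product attention (always available)."""
    q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
    k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
    v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
    x = F.scaled_dot_product_attention(q, k, v)
    return rearrange(x, "b n s d -> b s (n d)", n=num_heads)


def attention(q, k, v, num_heads):
    """Try the SageAttention kernel once; on failure, disable it and use SDPA."""
    global SAGE_ATTN_AVAILABLE
    if SAGE_ATTN_AVAILABLE:
        try:
            from sageattention import sageattn  # CUDA kernel; may not support this GPU
            qh = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
            kh = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
            vh = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
            x = sageattn(qh, kh, vh)
            return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
        except Exception:
            SAGE_ATTN_AVAILABLE = False  # later calls skip straight to SDPA
            print("[sketch] sageattn failed (unsupported GPU?), falling back to SDPA")
    return _sdpa_fallback(q, k, v, num_heads)

On hardware the kernel supports, the first call succeeds and the flag stays set; on a GPU it cannot handle (e.g. sm_120), the first call trips the except, clears the flag, and every later call goes straight to SDPA. The sketch rearranges into temporaries (qh, kh, vh) inside the try so the fallback still receives tensors in the original "b s (n d)" layout.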
@@ -218,27 +218,37 @@ def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
 # ----------------------------
 # Attention kernels
 # ----------------------------
+def _sdpa_fallback(q, k, v, num_heads):
+    """PyTorch scaled dot-product attention (always available)."""
+    q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+    k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+    v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+    x = F.scaled_dot_product_attention(q, k, v)
+    return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+
+
 def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
+    global SPARSE_SAGE_AVAILABLE, SAGE_ATTN_AVAILABLE, FLASH_ATTN_2_AVAILABLE, FLASH_ATTN_3_AVAILABLE
+
     if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
-        seqlen = q.shape[1]
-        seqlen_kv = k.shape[1]
-        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
-        base_blockmask = attention_mask
-        x = sparse_sageattn(
-            q, k, v,
-            mask_id=base_blockmask.to(torch.int8),
-            is_causal=False,
-            tensor_layout="HND"
-        )
-        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        try:
+            q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+            k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+            v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+            base_blockmask = attention_mask
+            x = sparse_sageattn(
+                q, k, v,
+                mask_id=base_blockmask.to(torch.int8),
+                is_causal=False,
+                tensor_layout="HND"
+            )
+            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        except Exception:
+            SPARSE_SAGE_AVAILABLE = False
+            print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
+            x = _sdpa_fallback(q, k, v, num_heads)
     elif compatibility_mode:
-        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
-        x = F.scaled_dot_product_attention(q, k, v)
-        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        x = _sdpa_fallback(q, k, v, num_heads)
     elif FLASH_ATTN_3_AVAILABLE:
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
@@ -254,17 +264,18 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
         x = flash_attn.flash_attn_func(q, k, v)
         x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
     elif SAGE_ATTN_AVAILABLE:
-        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
-        x = sageattn(q, k, v)
-        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        try:
+            q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+            k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+            v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+            x = sageattn(q, k, v)
+            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        except Exception:
+            SAGE_ATTN_AVAILABLE = False
+            print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
+            x = _sdpa_fallback(q, k, v, num_heads)
     else:
-        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
-        x = F.scaled_dot_product_attention(q, k, v)
-        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        x = _sdpa_fallback(q, k, v, num_heads)
     return x