Fix sageattn fallback: tensors already rearranged when exception fires
When sageattn fails, q/k/v are already in [b,n,s,d] format from the rearrange before the call. Use SDPA directly on them instead of calling _sdpa_fallback, which expects [b,s,(n*d)] and crashes with a shape error.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -246,7 +246,9 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
|
|||||||
except Exception:
|
except Exception:
|
||||||
SPARSE_SAGE_AVAILABLE = False
|
SPARSE_SAGE_AVAILABLE = False
|
||||||
print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
|
print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
|
||||||
x = _sdpa_fallback(q, k, v, num_heads)
|
# q,k,v already rearranged to [b, n, s, d] above
|
||||||
|
x = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||||
elif compatibility_mode:
|
elif compatibility_mode:
|
||||||
x = _sdpa_fallback(q, k, v, num_heads)
|
x = _sdpa_fallback(q, k, v, num_heads)
|
||||||
elif FLASH_ATTN_3_AVAILABLE:
|
elif FLASH_ATTN_3_AVAILABLE:
|
||||||
@@ -273,7 +275,9 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
|
|||||||
except Exception:
|
except Exception:
|
||||||
SAGE_ATTN_AVAILABLE = False
|
SAGE_ATTN_AVAILABLE = False
|
||||||
print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
|
print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
|
||||||
x = _sdpa_fallback(q, k, v, num_heads)
|
# q,k,v already rearranged to [b, n, s, d] above
|
||||||
|
x = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||||
else:
|
else:
|
||||||
x = _sdpa_fallback(q, k, v, num_heads)
|
x = _sdpa_fallback(q, k, v, num_heads)
|
||||||
return x
|
return x
|
||||||
|
|||||||
Reference in New Issue
Block a user