Fix attention dispatcher: use 4D tensors for SDPA, add math backend

SDPA with 3D xformers-BMK tensors cannot use Flash Attention and falls
back to efficient_attention/math kernels that miscompute on Ada Lovelace
GPUs (e.g. RTX 6000 Pro), producing brownish line artifacts.  Unsqueeze
to 4D (1, B*H, N, D) so Flash Attention is eligible.  Also add a naive
"math" backend (chunked bmm) as a guaranteed-correct diagnostic baseline.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-15 01:05:51 +01:00
parent f991f5cb02
commit 4c6c38f05a
2 changed files with 50 additions and 13 deletions

View File

@@ -114,6 +114,9 @@ for _name in _SAGE_VARIANTS:
except (ImportError, AttributeError): except (ImportError, AttributeError):
pass pass
# Manual attention (guaranteed correct, used as diagnostic baseline)
_ATTN_BACKENDS["math"] = "math"
_active_attn = "sdpa" _active_attn = "sdpa"
@@ -127,15 +130,31 @@ def _set_attn(backend: str):
def _dispatched_mea(q, k, v, attn_bias=None, op=None):
    """Dispatch memory-efficient attention to the active backend.

    Args:
        q, k, v: 3D xformers-BMK tensors ``(B*H, N, D)`` (per the commit
            message; assumed — TODO confirm against the caller).
        attn_bias: optional additive bias broadcastable to the logits,
            shape ``(B*H, Nq, Nk)`` or ``(1, Nq, Nk)``.
        op: xformers operator hint; forwarded only to the xformers backend.

    Returns:
        Attention output in the same 3D layout as ``q``.
    """
    if _active_attn == "xformers":
        return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
    if _active_attn == "math":
        # Naive batched attention — slow but guaranteed correct.
        scale = q.shape[-1] ** -0.5
        # Chunk the batch dim so one (Nq x Nk) logit block stays within a
        # ~2**28-byte budget.  Use Nq*Nk, not Nq**2: for cross-attention
        # with longer keys, Nq**2 overestimates the budget and a chunk can
        # exceed the intended cap.
        cs = max(1, 2**28 // (q.shape[1] * k.shape[1] * max(q.element_size(), 1)))
        outs = []
        for i in range(0, q.shape[0], cs):
            qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
            a = torch.bmm(qi * scale, ki.transpose(1, 2))
            if attn_bias is not None:
                # Slice per chunk when the bias is batched; a shared
                # (1, Nq, Nk) bias broadcasts as-is.
                a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
            outs.append(torch.bmm(a.softmax(dim=-1), vi))
        return torch.cat(outs)
    # SDPA and SageAttention both need 4D (batch, heads, seq, dim).
    # The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to
    # (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D
    # path falls back to efficient_attention / math which can miscompute
    # on newer Ada Lovelace GPUs).
    q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    if _active_attn == "sdpa":
        return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0)
    # SageAttention variants
    fn = _ATTN_BACKENDS[_active_attn]
    return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0)
_dispatched_mea._is_star_dispatcher = True

View File

@@ -78,6 +78,9 @@ for _name in _SAGE_VARIANTS:
except (ImportError, AttributeError): except (ImportError, AttributeError):
pass pass
# Manual attention (guaranteed correct, used as diagnostic baseline)
_ATTN_BACKENDS["math"] = "math"
_active_attn = "sdpa" _active_attn = "sdpa"
@@ -91,16 +94,31 @@ def _set_attn(backend: str):
def _dispatched_mea(q, k, v, attn_bias=None, op=None):
    """Dispatch memory-efficient attention to the active backend.

    Args:
        q, k, v: 3D xformers-BMK tensors ``(B*H, N, D)`` (per the commit
            message; assumed — TODO confirm against the caller).
        attn_bias: optional additive bias broadcastable to the logits,
            shape ``(B*H, Nq, Nk)`` or ``(1, Nq, Nk)``.
        op: xformers operator hint; forwarded only to the xformers backend.

    Returns:
        Attention output in the same 3D layout as ``q``.
    """
    if _active_attn == "xformers":
        return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
    if _active_attn == "math":
        # Naive batched attention — slow but guaranteed correct.
        scale = q.shape[-1] ** -0.5
        # Chunk the batch dim so one (Nq x Nk) logit block stays within a
        # ~2**28-byte budget.  Use Nq*Nk, not Nq**2: for cross-attention
        # with longer keys, Nq**2 overestimates the budget and a chunk can
        # exceed the intended cap.
        cs = max(1, 2**28 // (q.shape[1] * k.shape[1] * max(q.element_size(), 1)))
        outs = []
        for i in range(0, q.shape[0], cs):
            qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
            a = torch.bmm(qi * scale, ki.transpose(1, 2))
            if attn_bias is not None:
                # Slice per chunk when the bias is batched; a shared
                # (1, Nq, Nk) bias broadcasts as-is.
                a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
            outs.append(torch.bmm(a.softmax(dim=-1), vi))
        return torch.cat(outs)
    # SDPA and SageAttention both need 4D (batch, heads, seq, dim).
    # The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to
    # (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D
    # path falls back to efficient_attention / math which can miscompute
    # on newer Ada Lovelace GPUs).
    q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    if _active_attn == "sdpa":
        return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0)
    # SageAttention variants
    fn = _ATTN_BACKENDS[_active_attn]
    return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0)
_dispatched_mea._is_star_dispatcher = True