Use fp32 accumulation in SDPA and math attention to match xformers precision

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-15 01:47:10 +01:00
parent 0508868978
commit 2bf8db4f07
2 changed files with 26 additions and 12 deletions

View File

@@ -133,19 +133,26 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None):
if _active_attn == "xformers":
return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
if _active_attn == "math":
# Naive batched attention — slow but guaranteed correct.
# Naive batched attention with fp32 accumulation (matches xformers).
orig_dtype = q.dtype
scale = q.shape[-1] ** -0.5
cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1)))
cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * 4))
outs = []
for i in range(0, q.shape[0], cs):
qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
qi, ki, vi = q[i:i+cs].float(), k[i:i+cs].float(), v[i:i+cs].float()
a = torch.bmm(qi * scale, ki.transpose(1, 2))
if attn_bias is not None:
a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
bias = attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias
a = a + bias.float()
outs.append(torch.bmm(a.softmax(dim=-1), vi))
return torch.cat(outs)
return torch.cat(outs).to(orig_dtype)
if _active_attn == "sdpa":
return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
# xformers uses fp32 accumulation internally; match that by computing
# in fp32 and casting back.
orig_dtype = q.dtype
bias = attn_bias.float() if attn_bias is not None else None
out = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), attn_mask=bias)
return out.to(orig_dtype)
# SageAttention variants: need 4D tensors (batch, heads, seq, dim)
fn = _ATTN_BACKENDS[_active_attn]
return fn(