From 2bf8db4f07df2acf861e1d26c7afe1c1b1ef735e Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Sun, 15 Feb 2026 01:47:10 +0100
Subject: [PATCH] Use fp32 accumulation in SDPA and math attention to match
 xformers precision

Co-Authored-By: Claude Opus 4.6
---
 inference.py | 19 +++++++++++++------
 nodes.py     | 19 +++++++++++++------
 2 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/inference.py b/inference.py
index 4095e32..1950141 100755
--- a/inference.py
+++ b/inference.py
@@ -133,19 +133,26 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None):
     if _active_attn == "xformers":
         return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
     if _active_attn == "math":
-        # Naive batched attention — slow but guaranteed correct.
+        # Naive batched attention with fp32 accumulation (matches xformers).
+        orig_dtype = q.dtype
         scale = q.shape[-1] ** -0.5
-        cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1)))
+        cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * 4))
         outs = []
         for i in range(0, q.shape[0], cs):
-            qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
+            qi, ki, vi = q[i:i+cs].float(), k[i:i+cs].float(), v[i:i+cs].float()
             a = torch.bmm(qi * scale, ki.transpose(1, 2))
             if attn_bias is not None:
-                a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
+                bias = attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias
+                a = a + bias.float()
             outs.append(torch.bmm(a.softmax(dim=-1), vi))
-        return torch.cat(outs)
+        return torch.cat(outs).to(orig_dtype)
     if _active_attn == "sdpa":
-        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+        # xformers uses fp32 accumulation internally; match that by computing
+        # in fp32 and casting back.
+        orig_dtype = q.dtype
+        bias = attn_bias.float() if attn_bias is not None else None
+        out = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), attn_mask=bias)
+        return out.to(orig_dtype)
     # SageAttention variants: need 4D tensors (batch, heads, seq, dim)
     fn = _ATTN_BACKENDS[_active_attn]
     return fn(
diff --git a/nodes.py b/nodes.py
index c53724c..265cf55 100644
--- a/nodes.py
+++ b/nodes.py
@@ -97,19 +97,26 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None):
     if _active_attn == "xformers":
         return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
     if _active_attn == "math":
-        # Naive batched attention — slow but guaranteed correct.
+        # Naive batched attention with fp32 accumulation (matches xformers).
+        orig_dtype = q.dtype
         scale = q.shape[-1] ** -0.5
-        cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1)))
+        cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * 4))
         outs = []
         for i in range(0, q.shape[0], cs):
-            qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
+            qi, ki, vi = q[i:i+cs].float(), k[i:i+cs].float(), v[i:i+cs].float()
             a = torch.bmm(qi * scale, ki.transpose(1, 2))
             if attn_bias is not None:
-                a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
+                bias = attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias
+                a = a + bias.float()
             outs.append(torch.bmm(a.softmax(dim=-1), vi))
-        return torch.cat(outs)
+        return torch.cat(outs).to(orig_dtype)
     if _active_attn == "sdpa":
-        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+        # xformers uses fp32 accumulation internally; match that by computing
+        # in fp32 and casting back.
+        orig_dtype = q.dtype
+        bias = attn_bias.float() if attn_bias is not None else None
+        out = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), attn_mask=bias)
+        return out.to(orig_dtype)
     # SageAttention variants: need 4D tensors (batch, heads, seq, dim)
     fn = _ATTN_BACKENDS[_active_attn]
     return fn(
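
Reviewer note (not part of the patch): a minimal sketch of how the precision gap closed on the "sdpa" path could be checked. It compares the old half-precision call and the new fp32-upcast-then-cast-back call against a full-fp32 reference. The shapes, seed, and device fallback are illustrative assumptions, and it presumes a PyTorch build whose scaled_dot_product_attention accepts fp16 inputs on the chosen device.

    # sketch only; shapes/device are arbitrary assumptions
    import torch
    import torch.nn.functional as F

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 1024, 64, device=device, dtype=torch.float16) for _ in range(3))

    # Reference: everything computed in fp32.
    ref = F.scaled_dot_product_attention(q.float(), k.float(), v.float())

    # Old behaviour: compute directly in fp16.
    old = F.scaled_dot_product_attention(q, k, v).float()

    # Patched behaviour: upcast, compute in fp32, cast back to the input dtype.
    new = F.scaled_dot_product_attention(q.float(), k.float(), v.float()).to(q.dtype).float()

    print("fp16 path max |err|:   ", (old - ref).abs().max().item())
    print("patched path max |err|:", (new - ref).abs().max().item())

The patched path's residual error comes only from the final cast back to fp16, which is the same rounding xformers applies to its fp32-accumulated result.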