diff --git a/inference.py b/inference.py
index 2d242be..8cd60fa 100755
--- a/inference.py
+++ b/inference.py
@@ -114,6 +114,9 @@ for _name in _SAGE_VARIANTS:
     except (ImportError, AttributeError):
         pass
 
+# Manual attention (guaranteed correct, used as diagnostic baseline)
+_ATTN_BACKENDS["math"] = "math"
+
 _active_attn = "sdpa"
 
 
@@ -127,15 +130,31 @@ def _set_attn(backend: str):
 
 
 def _dispatched_mea(q, k, v, attn_bias=None, op=None):
-    if _active_attn == "sdpa":
-        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
     if _active_attn == "xformers":
         return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
+    if _active_attn == "math":
+        # Naive batched attention — slow but guaranteed correct.
+        scale = q.shape[-1] ** -0.5
+        cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1)))
+        outs = []
+        for i in range(0, q.shape[0], cs):
+            qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
+            a = torch.bmm(qi * scale, ki.transpose(1, 2))
+            if attn_bias is not None:
+                a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
+            outs.append(torch.bmm(a.softmax(dim=-1), vi))
+        return torch.cat(outs)
+    # SDPA and SageAttention both need 4D (batch, heads, seq, dim).
+    # The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to
+    # (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D
+    # path falls back to efficient_attention / math which can miscompute
+    # on newer Ada Lovelace GPUs).
+    q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
+    if _active_attn == "sdpa":
+        return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0)
+    # SageAttention variants
     fn = _ATTN_BACKENDS[_active_attn]
-    return fn(
-        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
-        tensor_layout="HND", is_causal=False,
-    ).squeeze(0)
+    return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0)
 
 
 _dispatched_mea._is_star_dispatcher = True
diff --git a/nodes.py b/nodes.py
index 70b36e7..0cc4f2f 100644
--- a/nodes.py
+++ b/nodes.py
@@ -78,6 +78,9 @@ for _name in _SAGE_VARIANTS:
     except (ImportError, AttributeError):
         pass
 
+# Manual attention (guaranteed correct, used as diagnostic baseline)
+_ATTN_BACKENDS["math"] = "math"
+
 _active_attn = "sdpa"
 
 
@@ -91,16 +94,31 @@ def _set_attn(backend: str):
 
 
 def _dispatched_mea(q, k, v, attn_bias=None, op=None):
-    if _active_attn == "sdpa":
-        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
     if _active_attn == "xformers":
         return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
-    # SageAttention variants: need 4D tensors (batch, heads, seq, dim)
+    if _active_attn == "math":
+        # Naive batched attention — slow but guaranteed correct.
+        scale = q.shape[-1] ** -0.5
+        cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1)))
+        outs = []
+        for i in range(0, q.shape[0], cs):
+            qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
+            a = torch.bmm(qi * scale, ki.transpose(1, 2))
+            if attn_bias is not None:
+                a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
+            outs.append(torch.bmm(a.softmax(dim=-1), vi))
+        return torch.cat(outs)
+    # SDPA and SageAttention both need 4D (batch, heads, seq, dim).
+    # The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to
+    # (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D
+    # path falls back to efficient_attention / math which can miscompute
+    # on newer Ada Lovelace GPUs).
+    q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
+    if _active_attn == "sdpa":
+        return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0)
+    # SageAttention variants
     fn = _ATTN_BACKENDS[_active_attn]
-    return fn(
-        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
-        tensor_layout="HND", is_causal=False,
-    ).squeeze(0)
+    return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0)
 
 
 _dispatched_mea._is_star_dispatcher = True
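
For reference, a minimal standalone sketch (not part of the patch) of what the two new dispatcher paths compute: the chunked-`bmm` "math" baseline and the unsqueeze-to-4D SDPA call it is meant to cross-check. The `math_attention` helper name, tensor shapes, and CPU-only setup are illustrative assumptions, not code from either file.

```python
import torch
import torch.nn.functional as F

def math_attention(q, k, v, attn_bias=None):
    # Chunked bmm formulation mirroring the "math" branch in the diff:
    # slice the (B*H) batch dimension to bound the size of the
    # intermediate (N x N) attention matrices.
    scale = q.shape[-1] ** -0.5
    cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1)))
    outs = []
    for i in range(0, q.shape[0], cs):
        qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
        a = torch.bmm(qi * scale, ki.transpose(1, 2))
        if attn_bias is not None:
            a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
        outs.append(torch.bmm(a.softmax(dim=-1), vi))
    return torch.cat(outs)

if __name__ == "__main__":
    torch.manual_seed(0)
    b_h, n, d = 8, 256, 64  # xformers-BMK layout: (B*H, N, D)
    q, k, v = (torch.randn(b_h, n, d) for _ in range(3))
    ref = math_attention(q, k, v)
    # The dispatcher's SDPA path: unsqueeze to (1, B*H, N, D), squeeze back.
    sdpa = F.scaled_dot_product_attention(
        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    ).squeeze(0)
    print("max abs diff:", (ref - sdpa).abs().max().item())
```

When both paths agree, the printed difference should sit at float32 rounding-error level.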