Fix attention dispatcher: use 4D tensors for SDPA, add math backend
SDPA with 3D xformers-BMK tensors cannot use Flash Attention and falls back to efficient_attention/math kernels that miscompute on Ada Lovelace GPUs (e.g. RTX 6000 Pro), producing brownish line artifacts. Unsqueeze to 4D (1, B*H, N, D) so Flash Attention is eligible. Also add a naive "math" backend (chunked bmm) as a guaranteed-correct diagnostic baseline. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
31
inference.py
31
inference.py
@@ -114,6 +114,9 @@ for _name in _SAGE_VARIANTS:
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
# Manual attention (guaranteed correct, used as diagnostic baseline)
|
||||
_ATTN_BACKENDS["math"] = "math"
|
||||
|
||||
_active_attn = "sdpa"
|
||||
|
||||
|
||||
@@ -127,15 +130,31 @@ def _set_attn(backend: str):
|
||||
|
||||
|
||||
def _dispatched_mea(q, k, v, attn_bias=None, op=None):
|
||||
if _active_attn == "sdpa":
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
|
||||
if _active_attn == "xformers":
|
||||
return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
|
||||
if _active_attn == "math":
|
||||
# Naive batched attention — slow but guaranteed correct.
|
||||
scale = q.shape[-1] ** -0.5
|
||||
cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1)))
|
||||
outs = []
|
||||
for i in range(0, q.shape[0], cs):
|
||||
qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
|
||||
a = torch.bmm(qi * scale, ki.transpose(1, 2))
|
||||
if attn_bias is not None:
|
||||
a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
|
||||
outs.append(torch.bmm(a.softmax(dim=-1), vi))
|
||||
return torch.cat(outs)
|
||||
# SDPA and SageAttention both need 4D (batch, heads, seq, dim).
|
||||
# The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to
|
||||
# (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D
|
||||
# path falls back to efficient_attention / math which can miscompute
|
||||
# on newer Ada Lovelace GPUs).
|
||||
q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
|
||||
if _active_attn == "sdpa":
|
||||
return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0)
|
||||
# SageAttention variants
|
||||
fn = _ATTN_BACKENDS[_active_attn]
|
||||
return fn(
|
||||
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
|
||||
tensor_layout="HND", is_causal=False,
|
||||
).squeeze(0)
|
||||
return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0)
|
||||
|
||||
|
||||
_dispatched_mea._is_star_dispatcher = True
|
||||
|
||||
32
nodes.py
32
nodes.py
@@ -78,6 +78,9 @@ for _name in _SAGE_VARIANTS:
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
# Manual attention (guaranteed correct, used as diagnostic baseline)
|
||||
_ATTN_BACKENDS["math"] = "math"
|
||||
|
||||
_active_attn = "sdpa"
|
||||
|
||||
|
||||
@@ -91,16 +94,31 @@ def _set_attn(backend: str):
|
||||
|
||||
|
||||
def _dispatched_mea(q, k, v, attn_bias=None, op=None):
|
||||
if _active_attn == "sdpa":
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
|
||||
if _active_attn == "xformers":
|
||||
return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
|
||||
# SageAttention variants: need 4D tensors (batch, heads, seq, dim)
|
||||
if _active_attn == "math":
|
||||
# Naive batched attention — slow but guaranteed correct.
|
||||
scale = q.shape[-1] ** -0.5
|
||||
cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1)))
|
||||
outs = []
|
||||
for i in range(0, q.shape[0], cs):
|
||||
qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
|
||||
a = torch.bmm(qi * scale, ki.transpose(1, 2))
|
||||
if attn_bias is not None:
|
||||
a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
|
||||
outs.append(torch.bmm(a.softmax(dim=-1), vi))
|
||||
return torch.cat(outs)
|
||||
# SDPA and SageAttention both need 4D (batch, heads, seq, dim).
|
||||
# The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to
|
||||
# (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D
|
||||
# path falls back to efficient_attention / math which can miscompute
|
||||
# on newer Ada Lovelace GPUs).
|
||||
q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
|
||||
if _active_attn == "sdpa":
|
||||
return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0)
|
||||
# SageAttention variants
|
||||
fn = _ATTN_BACKENDS[_active_attn]
|
||||
return fn(
|
||||
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
|
||||
tensor_layout="HND", is_causal=False,
|
||||
).squeeze(0)
|
||||
return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0)
|
||||
|
||||
|
||||
_dispatched_mea._is_star_dispatcher = True
|
||||
|
||||
Reference in New Issue
Block a user