Revert SDPA to 3D tensors — 4D unsqueeze caused quality degradation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 01:37:10 +01:00
parent f03c4853f1
commit 0508868978
2 changed files with 12 additions and 18 deletions
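
The dispatcher receives xformers-style BMK tensors of shape (B*H, N, D). torch.nn.functional.scaled_dot_product_attention accepts that 3D layout directly, which is the path this commit restores; the reverted change instead unsqueezed everything to (1, B*H, N, D). A minimal sketch of the two calls, not part of the commit, with illustrative shapes:

import torch
import torch.nn.functional as F

bh, n, d = 16, 128, 64                        # fused batch*heads, sequence length, head dim
q = torch.randn(bh, n, d)
k = torch.randn(bh, n, d)
v = torch.randn(bh, n, d)
bias = torch.randn(bh, n, n)                  # additive attention bias, one (n, n) slice per B*H

# Restored 3D path: SDPA treats dim 0 as an ordinary batch dimension.
out_3d = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# Reverted 4D path: (1, B*H, N, D) presents the fused axis as heads of a single batch.
out_4d = F.scaled_dot_product_attention(
    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_mask=bias
).squeeze(0)

# The two calls compute the same attention; per the comment removed in this
# commit, the practical difference is which SDPA kernel gets selected.
print(out_3d.shape, torch.allclose(out_3d, out_4d, atol=1e-5))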


@@ -144,17 +144,14 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None):
             a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
             outs.append(torch.bmm(a.softmax(dim=-1), vi))
         return torch.cat(outs)
-    # SDPA and SageAttention both need 4D (batch, heads, seq, dim).
-    # The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to
-    # (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D
-    # path falls back to efficient_attention / math which can miscompute
-    # on newer Ada Lovelace GPUs).
-    q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
     if _active_attn == "sdpa":
-        return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0)
-    # SageAttention variants
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+    # SageAttention variants: need 4D tensors (batch, heads, seq, dim)
     fn = _ATTN_BACKENDS[_active_attn]
-    return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0)
+    return fn(
+        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
+        tensor_layout="HND", is_causal=False,
+    ).squeeze(0)
 _dispatched_mea._is_star_dispatcher = True


@@ -108,17 +108,14 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None):
             a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
             outs.append(torch.bmm(a.softmax(dim=-1), vi))
         return torch.cat(outs)
-    # SDPA and SageAttention both need 4D (batch, heads, seq, dim).
-    # The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to
-    # (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D
-    # path falls back to efficient_attention / math which can miscompute
-    # on newer Ada Lovelace GPUs).
-    q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
     if _active_attn == "sdpa":
-        return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0)
-    # SageAttention variants
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+    # SageAttention variants: need 4D tensors (batch, heads, seq, dim)
     fn = _ATTN_BACKENDS[_active_attn]
-    return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0)
+    return fn(
+        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
+        tensor_layout="HND", is_causal=False,
+    ).squeeze(0)
 _dispatched_mea._is_star_dispatcher = True
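
The SageAttention branch keeps the unsqueeze because those kernels take only 4D input; tensor_layout="HND" names the (batch, heads, seq, dim) order. A hedged sketch of that shape plumbing, assuming the sageattention package's sageattn entry point stands in for whatever _ATTN_BACKENDS resolves to (the registry itself is not shown in this diff), with a hypothetical sage_from_bmk helper:

import torch

try:
    # Assumed import: sageattn is the public entry point of the sageattention
    # package; the repo's _ATTN_BACKENDS mapping is defined elsewhere.
    from sageattention import sageattn
except ImportError:
    sageattn = None

def sage_from_bmk(q, k, v):
    # Hypothetical helper: run a SageAttention kernel on xformers-BMK
    # (B*H, N, D) tensors by presenting the fused B*H axis as the heads
    # axis of a single batch, matching tensor_layout="HND".
    out = sageattn(
        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
        tensor_layout="HND", is_causal=False,
    )
    return out.squeeze(0)    # back to (B*H, N, D) for the caller

if sageattn is not None and torch.cuda.is_available():
    q = torch.randn(16, 128, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    print(sage_from_bmk(q, k, v).shape)    # torch.Size([16, 128, 64])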