Revert SDPA to 3D tensors — 4D unsqueeze caused quality degradation
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
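For context: F.scaled_dot_product_attention accepts any number of leading batch dimensions, so the 3D xformers-style (B*H, N, D) tensors the model emits can be passed in directly; the removed 4D path only reshaped them to (1, B*H, N, D) to influence kernel selection. A minimal sketch of the two call shapes this commit switches between (shapes and names are illustrative, not taken from the repo):

    import torch
    import torch.nn.functional as F

    bh, n, d = 8 * 16, 256, 64  # (batch * heads, seq_len, head_dim)
    q = torch.randn(bh, n, d)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Reverted-to path: 3D tensors go straight into SDPA.
    out_3d = F.scaled_dot_product_attention(q, k, v)

    # Removed path: unsqueeze to (1, B*H, N, D), squeeze afterwards.
    out_4d = F.scaled_dot_product_attention(
        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    ).squeeze(0)

    # Same computation either way; only the backend kernel chosen on
    # GPU can differ, which is what the quality report points at.
    print((out_3d - out_4d).abs().max())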
inference.py · 15 changed lines (+6 −9)
@@ -144,17 +144,14 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None):
                 a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
             outs.append(torch.bmm(a.softmax(dim=-1), vi))
         return torch.cat(outs)
-    # SDPA and SageAttention both need 4D (batch, heads, seq, dim).
-    # The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to
-    # (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D
-    # path falls back to efficient_attention / math which can miscompute
-    # on newer Ada Lovelace GPUs).
-    q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
     if _active_attn == "sdpa":
-        return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0)
-    # SageAttention variants
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+    # SageAttention variants: need 4D tensors (batch, heads, seq, dim)
     fn = _ATTN_BACKENDS[_active_attn]
-    return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0)
+    return fn(
+        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
+        tensor_layout="HND", is_causal=False,
+    ).squeeze(0)


 _dispatched_mea._is_star_dispatcher = True
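The SageAttention branch keeps the 4D round-trip because its kernels accept only (batch, heads, seq, dim) inputs. A standalone sketch of that adapter, assuming the sageattention package's sageattn entry point (the repo dispatches through _ATTN_BACKENDS instead):

    import torch
    from sageattention import sageattn

    def sage_from_bmk(q, k, v):
        # q, k, v arrive in xformers BMK layout: (B*H, N, D).
        # Treat B*H as the head axis with batch 1, which matches
        # tensor_layout="HND" (batch, heads, seq, dim).
        return sageattn(
            q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
            tensor_layout="HND", is_causal=False,
        ).squeeze(0)

Note that attn_bias is never forwarded on this branch; in the diff, biased attention presumably takes the chunked torch.bmm fallback earlier in the function.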
nodes.py · 15 changed lines (+6 −9)
@@ -108,17 +108,14 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None):
                 a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
             outs.append(torch.bmm(a.softmax(dim=-1), vi))
         return torch.cat(outs)
-    # SDPA and SageAttention both need 4D (batch, heads, seq, dim).
-    # The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to
-    # (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D
-    # path falls back to efficient_attention / math which can miscompute
-    # on newer Ada Lovelace GPUs).
-    q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
     if _active_attn == "sdpa":
-        return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0)
-    # SageAttention variants
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+    # SageAttention variants: need 4D tensors (batch, heads, seq, dim)
     fn = _ATTN_BACKENDS[_active_attn]
-    return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0)
+    return fn(
+        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
+        tensor_layout="HND", is_causal=False,
+    ).squeeze(0)


 _dispatched_mea._is_star_dispatcher = True
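Since the commit attributes the degradation to kernel choice rather than to the math, any of these routes can be checked against the same reference computation the chunked fallback uses. An illustrative self-test, not part of either file:

    import math
    import torch
    import torch.nn.functional as F

    def bmm_reference(q, k, v, attn_bias=None):
        # softmax(q @ k^T / sqrt(d)) @ v: the same math as the
        # torch.bmm fallback in _dispatched_mea, without chunking.
        a = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
        if attn_bias is not None:
            a = a + attn_bias
        return torch.bmm(a.softmax(dim=-1), v)

    q = torch.randn(8, 128, 64)
    k, v = torch.randn_like(q), torch.randn_like(q)
    torch.testing.assert_close(
        F.scaled_dot_product_attention(q, k, v),
        bmm_reference(q, k, v),
        rtol=1e-4, atol=1e-4,
    )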