From 0508868978600e2feeddffb3ce77070b9250e783 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 15 Feb 2026 01:37:10 +0100 Subject: [PATCH] =?UTF-8?q?Revert=20SDPA=20to=203D=20tensors=20=E2=80=94?= =?UTF-8?q?=204D=20unsqueeze=20caused=20quality=20degradation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 --- inference.py | 15 ++++++--------- nodes.py | 15 ++++++--------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/inference.py b/inference.py index acf54fe..4095e32 100755 --- a/inference.py +++ b/inference.py @@ -144,17 +144,14 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None): a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias) outs.append(torch.bmm(a.softmax(dim=-1), vi)) return torch.cat(outs) - # SDPA and SageAttention both need 4D (batch, heads, seq, dim). - # The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to - # (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D - # path falls back to efficient_attention / math which can miscompute - # on newer Ada Lovelace GPUs). - q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0) if _active_attn == "sdpa": - return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0) - # SageAttention variants + return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) + # SageAttention variants: need 4D tensors (batch, heads, seq, dim) fn = _ATTN_BACKENDS[_active_attn] - return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0) + return fn( + q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), + tensor_layout="HND", is_causal=False, + ).squeeze(0) _dispatched_mea._is_star_dispatcher = True diff --git a/nodes.py b/nodes.py index 6443992..c53724c 100644 --- a/nodes.py +++ b/nodes.py @@ -108,17 +108,14 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None): a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias) outs.append(torch.bmm(a.softmax(dim=-1), vi)) return torch.cat(outs) - # SDPA and SageAttention both need 4D (batch, heads, seq, dim). - # The model sends 3D xformers-BMK tensors (B*H, N, D); unsqueeze to - # (1, B*H, N, D) so SDPA can pick the Flash Attention kernel (the 3D - # path falls back to efficient_attention / math which can miscompute - # on newer Ada Lovelace GPUs). - q4, k4, v4 = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0) if _active_attn == "sdpa": - return F.scaled_dot_product_attention(q4, k4, v4, attn_mask=attn_bias).squeeze(0) - # SageAttention variants + return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) + # SageAttention variants: need 4D tensors (batch, heads, seq, dim) fn = _ATTN_BACKENDS[_active_attn] - return fn(q4, k4, v4, tensor_layout="HND", is_causal=False).squeeze(0) + return fn( + q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), + tensor_layout="HND", is_causal=False, + ).squeeze(0) _dispatched_mea._is_star_dispatcher = True