Bundle sparse_sage Triton kernel for block-sparse attention
Without sparse attention, the model uses full (dense) attention which attends to distant irrelevant information, causing ghosting artifacts. The FlashVSR paper explicitly requires block-sparse attention. Vendored from SageAttention team (Apache 2.0), pure Triton (no CUDA C++). Import chain: local sparse_sage → external sageattn.core → SDPA fallback. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
3
flashvsr_arch/models/sparse_sage/__init__.py
Normal file
3
flashvsr_arch/models/sparse_sage/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .core import sparse_sageattn
|
||||||
|
|
||||||
|
__all__ = ["sparse_sageattn"]
|
||||||
40
flashvsr_arch/models/sparse_sage/core.py
Normal file
40
flashvsr_arch/models/sparse_sage/core.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
Sparse SageAttention — block-sparse INT8 attention via Triton.
|
||||||
|
|
||||||
|
https://github.com/jt-zhang/Sparse_SageAttention_API
|
||||||
|
|
||||||
|
Copyright (c) 2024 by SageAttention team.
|
||||||
|
Licensed under the Apache License, Version 2.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .quant_per_block import per_block_int8
|
||||||
|
from .sparse_int8_attn import forward as sparse_sageattn_fwd
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_sageattn(q, k, v, mask_id=None, is_causal=False, tensor_layout="HND"):
|
||||||
|
if mask_id is None:
|
||||||
|
mask_id = torch.ones(
|
||||||
|
(q.shape[0], q.shape[1],
|
||||||
|
(q.shape[2] + 128 - 1) // 128,
|
||||||
|
(q.shape[3] + 64 - 1) // 64),
|
||||||
|
dtype=torch.int8, device=q.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_dtype = q.dtype
|
||||||
|
if output_dtype == torch.bfloat16 or output_dtype == torch.float32:
|
||||||
|
v = v.to(torch.float16)
|
||||||
|
|
||||||
|
seq_dim = 1 if tensor_layout == "NHD" else 2
|
||||||
|
km = k.mean(dim=seq_dim, keepdim=True)
|
||||||
|
|
||||||
|
q_int8, q_scale, k_int8, k_scale = per_block_int8(
|
||||||
|
q, k, km=km, tensor_layout=tensor_layout,
|
||||||
|
)
|
||||||
|
|
||||||
|
o = sparse_sageattn_fwd(
|
||||||
|
q_int8, k_int8, mask_id, v, q_scale, k_scale,
|
||||||
|
is_causal=is_causal, tensor_layout=tensor_layout,
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
)
|
||||||
|
return o
|
||||||
110
flashvsr_arch/models/sparse_sage/quant_per_block.py
Normal file
110
flashvsr_arch/models/sparse_sage/quant_per_block.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""
|
||||||
|
Per-block INT8 quantization kernel for Sparse SageAttention.
|
||||||
|
|
||||||
|
Copyright (c) 2024 by SageAttention team.
|
||||||
|
Licensed under the Apache License, Version 2.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def quant_per_block_int8_kernel(
|
||||||
|
Input, Output, Scale, L,
|
||||||
|
stride_iz, stride_ih, stride_in,
|
||||||
|
stride_oz, stride_oh, stride_on,
|
||||||
|
stride_sz, stride_sh,
|
||||||
|
sm_scale,
|
||||||
|
C: tl.constexpr, BLK: tl.constexpr,
|
||||||
|
):
|
||||||
|
off_blk = tl.program_id(0)
|
||||||
|
off_h = tl.program_id(1)
|
||||||
|
off_b = tl.program_id(2)
|
||||||
|
|
||||||
|
offs_n = off_blk * BLK + tl.arange(0, BLK)
|
||||||
|
offs_k = tl.arange(0, C)
|
||||||
|
|
||||||
|
input_ptrs = (
|
||||||
|
Input
|
||||||
|
+ off_b * stride_iz
|
||||||
|
+ off_h * stride_ih
|
||||||
|
+ offs_n[:, None] * stride_in
|
||||||
|
+ offs_k[None, :]
|
||||||
|
)
|
||||||
|
output_ptrs = (
|
||||||
|
Output
|
||||||
|
+ off_b * stride_oz
|
||||||
|
+ off_h * stride_oh
|
||||||
|
+ offs_n[:, None] * stride_on
|
||||||
|
+ offs_k[None, :]
|
||||||
|
)
|
||||||
|
scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk
|
||||||
|
|
||||||
|
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
|
||||||
|
x = x.to(tl.float32)
|
||||||
|
x *= sm_scale
|
||||||
|
scale = tl.max(tl.abs(x)) / 127.0
|
||||||
|
x_int8 = x / scale
|
||||||
|
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
|
||||||
|
x_int8 = x_int8.to(tl.int8)
|
||||||
|
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
|
||||||
|
tl.store(scale_ptrs, scale)
|
||||||
|
|
||||||
|
|
||||||
|
def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):
|
||||||
|
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
|
||||||
|
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
|
||||||
|
|
||||||
|
if km is not None:
|
||||||
|
k = k - km
|
||||||
|
|
||||||
|
if tensor_layout == "HND":
|
||||||
|
b, h_qo, qo_len, head_dim = q.shape
|
||||||
|
_, h_kv, kv_len, _ = k.shape
|
||||||
|
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
|
||||||
|
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
|
||||||
|
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
|
||||||
|
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
|
||||||
|
elif tensor_layout == "NHD":
|
||||||
|
b, qo_len, h_qo, head_dim = q.shape
|
||||||
|
_, kv_len, h_kv, _ = k.shape
|
||||||
|
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
|
||||||
|
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
|
||||||
|
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
|
||||||
|
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown tensor layout: {tensor_layout}")
|
||||||
|
|
||||||
|
q_scale = torch.empty(
|
||||||
|
(b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32,
|
||||||
|
)
|
||||||
|
k_scale = torch.empty(
|
||||||
|
(b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
if sm_scale is None:
|
||||||
|
sm_scale = head_dim ** -0.5
|
||||||
|
|
||||||
|
grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b)
|
||||||
|
quant_per_block_int8_kernel[grid](
|
||||||
|
q, q_int8, q_scale, qo_len,
|
||||||
|
stride_bz_q, stride_h_q, stride_seq_q,
|
||||||
|
stride_bz_qo, stride_h_qo, stride_seq_qo,
|
||||||
|
q_scale.stride(0), q_scale.stride(1),
|
||||||
|
sm_scale=(sm_scale * 1.44269504),
|
||||||
|
C=head_dim, BLK=BLKQ,
|
||||||
|
)
|
||||||
|
|
||||||
|
grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b)
|
||||||
|
quant_per_block_int8_kernel[grid](
|
||||||
|
k, k_int8, k_scale, kv_len,
|
||||||
|
stride_bz_k, stride_h_k, stride_seq_k,
|
||||||
|
stride_bz_ko, stride_h_ko, stride_seq_ko,
|
||||||
|
k_scale.stride(0), k_scale.stride(1),
|
||||||
|
sm_scale=1.0,
|
||||||
|
C=head_dim, BLK=BLKK,
|
||||||
|
)
|
||||||
|
|
||||||
|
return q_int8, q_scale, k_int8, k_scale
|
||||||
196
flashvsr_arch/models/sparse_sage/sparse_int8_attn.py
Normal file
196
flashvsr_arch/models/sparse_sage/sparse_int8_attn.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""
|
||||||
|
Sparse INT8 attention kernel for Sparse SageAttention.
|
||||||
|
|
||||||
|
Copyright (c) 2024 by SageAttention team.
|
||||||
|
Licensed under the Apache License, Version 2.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _attn_fwd_inner(
|
||||||
|
acc, l_i, old_m, q, q_scale, kv_len,
|
||||||
|
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
|
||||||
|
stride_kn, stride_vn, start_m,
|
||||||
|
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
|
||||||
|
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
|
||||||
|
):
|
||||||
|
if STAGE == 1:
|
||||||
|
lo, hi = 0, start_m * BLOCK_M
|
||||||
|
elif STAGE == 2:
|
||||||
|
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
|
||||||
|
lo = tl.multiple_of(lo, BLOCK_M)
|
||||||
|
K_scale_ptr += lo // BLOCK_N
|
||||||
|
K_ptrs += stride_kn * lo
|
||||||
|
V_ptrs += stride_vn * lo
|
||||||
|
elif STAGE == 3:
|
||||||
|
lo, hi = 0, kv_len
|
||||||
|
for start_n in range(lo, hi, BLOCK_N):
|
||||||
|
kbid = tl.load(K_bid_ptr + start_n // BLOCK_N)
|
||||||
|
if kbid:
|
||||||
|
k_mask = offs_n[None, :] < (kv_len - start_n)
|
||||||
|
k = tl.load(K_ptrs, mask=k_mask)
|
||||||
|
k_scale = tl.load(K_scale_ptr)
|
||||||
|
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
|
||||||
|
if STAGE == 2:
|
||||||
|
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
|
||||||
|
qk = qk + tl.where(mask, 0, -1.0e6)
|
||||||
|
local_m = tl.max(qk, 1)
|
||||||
|
new_m = tl.maximum(old_m, local_m)
|
||||||
|
qk -= new_m[:, None]
|
||||||
|
else:
|
||||||
|
local_m = tl.max(qk, 1)
|
||||||
|
new_m = tl.maximum(old_m, local_m)
|
||||||
|
qk = qk - new_m[:, None]
|
||||||
|
|
||||||
|
p = tl.math.exp2(qk)
|
||||||
|
l_ij = tl.sum(p, 1)
|
||||||
|
alpha = tl.math.exp2(old_m - new_m)
|
||||||
|
l_i = l_i * alpha + l_ij
|
||||||
|
acc = acc * alpha[:, None]
|
||||||
|
v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
|
||||||
|
p = p.to(tl.float16)
|
||||||
|
acc += tl.dot(p, v, out_dtype=tl.float16)
|
||||||
|
old_m = new_m
|
||||||
|
K_ptrs += BLOCK_N * stride_kn
|
||||||
|
K_scale_ptr += 1
|
||||||
|
V_ptrs += BLOCK_N * stride_vn
|
||||||
|
return acc, l_i, old_m
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _attn_fwd(
|
||||||
|
Q, K, K_blkid, V, Q_scale, K_scale, Out,
|
||||||
|
stride_qz, stride_qh, stride_qn,
|
||||||
|
stride_kz, stride_kh, stride_kn,
|
||||||
|
stride_vz, stride_vh, stride_vn,
|
||||||
|
stride_oz, stride_oh, stride_on,
|
||||||
|
stride_kbidq, stride_kbidk,
|
||||||
|
qo_len, kv_len,
|
||||||
|
H: tl.constexpr, num_kv_groups: tl.constexpr,
|
||||||
|
HEAD_DIM: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
STAGE: tl.constexpr,
|
||||||
|
):
|
||||||
|
start_m = tl.program_id(0)
|
||||||
|
off_z = tl.program_id(2).to(tl.int64)
|
||||||
|
off_h = tl.program_id(1).to(tl.int64)
|
||||||
|
q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
|
||||||
|
k_scale_offset = (
|
||||||
|
off_z * (H // num_kv_groups) + off_h // num_kv_groups
|
||||||
|
) * tl.cdiv(kv_len, BLOCK_N)
|
||||||
|
k_bid_offset = (
|
||||||
|
off_z * (H // num_kv_groups) + off_h // num_kv_groups
|
||||||
|
) * stride_kbidq
|
||||||
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
offs_n = tl.arange(0, BLOCK_N)
|
||||||
|
offs_k = tl.arange(0, HEAD_DIM)
|
||||||
|
Q_ptrs = (
|
||||||
|
Q
|
||||||
|
+ (off_z * stride_qz + off_h * stride_qh)
|
||||||
|
+ offs_m[:, None] * stride_qn
|
||||||
|
+ offs_k[None, :]
|
||||||
|
)
|
||||||
|
Q_scale_ptr = Q_scale + q_scale_offset + start_m
|
||||||
|
K_ptrs = (
|
||||||
|
K
|
||||||
|
+ (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh)
|
||||||
|
+ offs_n[None, :] * stride_kn
|
||||||
|
+ offs_k[:, None]
|
||||||
|
)
|
||||||
|
K_scale_ptr = K_scale + k_scale_offset
|
||||||
|
K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk
|
||||||
|
V_ptrs = (
|
||||||
|
V
|
||||||
|
+ (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh)
|
||||||
|
+ offs_n[:, None] * stride_vn
|
||||||
|
+ offs_k[None, :]
|
||||||
|
)
|
||||||
|
O_block_ptr = (
|
||||||
|
Out
|
||||||
|
+ (off_z * stride_oz + off_h * stride_oh)
|
||||||
|
+ offs_m[:, None] * stride_on
|
||||||
|
+ offs_k[None, :]
|
||||||
|
)
|
||||||
|
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||||
|
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
|
||||||
|
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
|
||||||
|
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
|
||||||
|
q_scale = tl.load(Q_scale_ptr)
|
||||||
|
acc, l_i, m_i = _attn_fwd_inner(
|
||||||
|
acc, l_i, m_i, q, q_scale, kv_len,
|
||||||
|
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
|
||||||
|
stride_kn, stride_vn,
|
||||||
|
start_m,
|
||||||
|
BLOCK_M, HEAD_DIM, BLOCK_N,
|
||||||
|
4 - STAGE, offs_m, offs_n,
|
||||||
|
)
|
||||||
|
if STAGE != 1:
|
||||||
|
acc, l_i, _ = _attn_fwd_inner(
|
||||||
|
acc, l_i, m_i, q, q_scale, kv_len,
|
||||||
|
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
|
||||||
|
stride_kn, stride_vn,
|
||||||
|
start_m,
|
||||||
|
BLOCK_M, HEAD_DIM, BLOCK_N,
|
||||||
|
2, offs_m, offs_n,
|
||||||
|
)
|
||||||
|
acc = acc / l_i[:, None]
|
||||||
|
tl.store(
|
||||||
|
O_block_ptr,
|
||||||
|
acc.to(Out.type.element_ty),
|
||||||
|
mask=(offs_m[:, None] < qo_len),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
q, k, k_block_id, v, q_scale, k_scale,
|
||||||
|
is_causal=False, tensor_layout="HND", output_dtype=torch.float16,
|
||||||
|
):
|
||||||
|
BLOCK_M = 128
|
||||||
|
BLOCK_N = 64
|
||||||
|
stage = 3 if is_causal else 1
|
||||||
|
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
|
||||||
|
|
||||||
|
if tensor_layout == "HND":
|
||||||
|
b, h_qo, qo_len, head_dim = q.shape
|
||||||
|
_, h_kv, kv_len, _ = k.shape
|
||||||
|
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
|
||||||
|
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
|
||||||
|
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
|
||||||
|
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
|
||||||
|
elif tensor_layout == "NHD":
|
||||||
|
b, qo_len, h_qo, head_dim = q.shape
|
||||||
|
_, kv_len, h_kv, _ = k.shape
|
||||||
|
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
|
||||||
|
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
|
||||||
|
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
|
||||||
|
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"tensor_layout {tensor_layout} not supported")
|
||||||
|
|
||||||
|
if is_causal:
|
||||||
|
assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"
|
||||||
|
|
||||||
|
HEAD_DIM_K = head_dim
|
||||||
|
num_kv_groups = h_qo // h_kv
|
||||||
|
|
||||||
|
grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b)
|
||||||
|
_attn_fwd[grid](
|
||||||
|
q, k, k_block_id, v, q_scale, k_scale, o,
|
||||||
|
stride_bz_q, stride_h_q, stride_seq_q,
|
||||||
|
stride_bz_k, stride_h_k, stride_seq_k,
|
||||||
|
stride_bz_v, stride_h_v, stride_seq_v,
|
||||||
|
stride_bz_o, stride_h_o, stride_seq_o,
|
||||||
|
k_block_id.stride(1), k_block_id.stride(2),
|
||||||
|
qo_len, kv_len,
|
||||||
|
h_qo, num_kv_groups,
|
||||||
|
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
|
||||||
|
STAGE=stage,
|
||||||
|
num_warps=4 if head_dim == 64 else 8,
|
||||||
|
num_stages=4,
|
||||||
|
)
|
||||||
|
return o
|
||||||
@@ -31,15 +31,24 @@ except Exception:
|
|||||||
SAGE_ATTN_AVAILABLE = False
|
SAGE_ATTN_AVAILABLE = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sageattn.core import sparse_sageattn
|
from .sparse_sage.core import sparse_sageattn
|
||||||
assert callable(sparse_sageattn)
|
assert callable(sparse_sageattn)
|
||||||
SPARSE_SAGE_AVAILABLE = True
|
SPARSE_SAGE_AVAILABLE = True
|
||||||
except Exception:
|
except Exception:
|
||||||
SPARSE_SAGE_AVAILABLE = False
|
try:
|
||||||
sparse_sageattn = None
|
from sageattn.core import sparse_sageattn
|
||||||
|
assert callable(sparse_sageattn)
|
||||||
|
SPARSE_SAGE_AVAILABLE = True
|
||||||
|
except Exception:
|
||||||
|
SPARSE_SAGE_AVAILABLE = False
|
||||||
|
sparse_sageattn = None
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
print(f"[FlashVSR] Attention backends: sparse_sage={SPARSE_SAGE_AVAILABLE}, "
|
||||||
|
f"flash_attn_3={FLASH_ATTN_3_AVAILABLE}, flash_attn_2={FLASH_ATTN_2_AVAILABLE}, "
|
||||||
|
f"sage_attn={SAGE_ATTN_AVAILABLE}")
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# Local / window masks
|
# Local / window masks
|
||||||
|
|||||||
Reference in New Issue
Block a user