diff --git a/flashvsr_arch/models/sparse_sage/__init__.py b/flashvsr_arch/models/sparse_sage/__init__.py new file mode 100644 index 0000000..95351e5 --- /dev/null +++ b/flashvsr_arch/models/sparse_sage/__init__.py @@ -0,0 +1,3 @@ +from .core import sparse_sageattn + +__all__ = ["sparse_sageattn"] diff --git a/flashvsr_arch/models/sparse_sage/core.py b/flashvsr_arch/models/sparse_sage/core.py new file mode 100644 index 0000000..a022771 --- /dev/null +++ b/flashvsr_arch/models/sparse_sage/core.py @@ -0,0 +1,40 @@ +""" +Sparse SageAttention — block-sparse INT8 attention via Triton. + +https://github.com/jt-zhang/Sparse_SageAttention_API + +Copyright (c) 2024 by SageAttention team. +Licensed under the Apache License, Version 2.0 +""" + +from .quant_per_block import per_block_int8 +from .sparse_int8_attn import forward as sparse_sageattn_fwd +import torch + + +def sparse_sageattn(q, k, v, mask_id=None, is_causal=False, tensor_layout="HND"): + if mask_id is None: + mask_id = torch.ones( + (q.shape[0], q.shape[1], + (q.shape[2] + 128 - 1) // 128, + (q.shape[3] + 64 - 1) // 64), + dtype=torch.int8, device=q.device, + ) + + output_dtype = q.dtype + if output_dtype == torch.bfloat16 or output_dtype == torch.float32: + v = v.to(torch.float16) + + seq_dim = 1 if tensor_layout == "NHD" else 2 + km = k.mean(dim=seq_dim, keepdim=True) + + q_int8, q_scale, k_int8, k_scale = per_block_int8( + q, k, km=km, tensor_layout=tensor_layout, + ) + + o = sparse_sageattn_fwd( + q_int8, k_int8, mask_id, v, q_scale, k_scale, + is_causal=is_causal, tensor_layout=tensor_layout, + output_dtype=output_dtype, + ) + return o diff --git a/flashvsr_arch/models/sparse_sage/quant_per_block.py b/flashvsr_arch/models/sparse_sage/quant_per_block.py new file mode 100644 index 0000000..a1d317e --- /dev/null +++ b/flashvsr_arch/models/sparse_sage/quant_per_block.py @@ -0,0 +1,110 @@ +""" +Per-block INT8 quantization kernel for Sparse SageAttention. + +Copyright (c) 2024 by SageAttention team. +Licensed under the Apache License, Version 2.0 +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def quant_per_block_int8_kernel( + Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + sm_scale, + C: tl.constexpr, BLK: tl.constexpr, +): + off_blk = tl.program_id(0) + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK) + offs_k = tl.arange(0, C) + + input_ptrs = ( + Input + + off_b * stride_iz + + off_h * stride_ih + + offs_n[:, None] * stride_in + + offs_k[None, :] + ) + output_ptrs = ( + Output + + off_b * stride_oz + + off_h * stride_oh + + offs_n[:, None] * stride_on + + offs_k[None, :] + ) + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + x *= sm_scale + scale = tl.max(tl.abs(x)) / 127.0 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + + +def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32, + ) + + if sm_scale is None: + sm_scale = head_dim ** -0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b) + quant_per_block_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + sm_scale=(sm_scale * 1.44269504), + C=head_dim, BLK=BLKQ, + ) + + grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b) + quant_per_block_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + sm_scale=1.0, + C=head_dim, BLK=BLKK, + ) + + return q_int8, q_scale, k_int8, k_scale diff --git a/flashvsr_arch/models/sparse_sage/sparse_int8_attn.py b/flashvsr_arch/models/sparse_sage/sparse_int8_attn.py new file mode 100644 index 0000000..65a89c4 --- /dev/null +++ b/flashvsr_arch/models/sparse_sage/sparse_int8_attn.py @@ -0,0 +1,196 @@ +""" +Sparse INT8 attention kernel for Sparse SageAttention. + +Copyright (c) 2024 by SageAttention team. +Licensed under the Apache License, Version 2.0 +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _attn_fwd_inner( + acc, l_i, old_m, q, q_scale, kv_len, + K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, + stride_kn, stride_vn, start_m, + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, +): + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_scale_ptr += lo // BLOCK_N + K_ptrs += stride_kn * lo + V_ptrs += stride_vn * lo + elif STAGE == 3: + lo, hi = 0, kv_len + for start_n in range(lo, hi, BLOCK_N): + kbid = tl.load(K_bid_ptr + start_n // BLOCK_N) + if kbid: + k_mask = offs_n[None, :] < (kv_len - start_n) + k = tl.load(K_ptrs, mask=k_mask) + k_scale = tl.load(K_scale_ptr) + qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk + tl.where(mask, 0, -1.0e6) + local_m = tl.max(qk, 1) + new_m = tl.maximum(old_m, local_m) + qk -= new_m[:, None] + else: + local_m = tl.max(qk, 1) + new_m = tl.maximum(old_m, local_m) + qk = qk - new_m[:, None] + + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(old_m - new_m) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n)) + p = p.to(tl.float16) + acc += tl.dot(p, v, out_dtype=tl.float16) + old_m = new_m + K_ptrs += BLOCK_N * stride_kn + K_scale_ptr += 1 + V_ptrs += BLOCK_N * stride_vn + return acc, l_i, old_m + + +@triton.jit +def _attn_fwd( + Q, K, K_blkid, V, Q_scale, K_scale, Out, + stride_qz, stride_qh, stride_qn, + stride_kz, stride_kh, stride_kn, + stride_vz, stride_vh, stride_vn, + stride_oz, stride_oh, stride_on, + stride_kbidq, stride_kbidk, + qo_len, kv_len, + H: tl.constexpr, num_kv_groups: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, +): + start_m = tl.program_id(0) + off_z = tl.program_id(2).to(tl.int64) + off_h = tl.program_id(1).to(tl.int64) + q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M) + k_scale_offset = ( + off_z * (H // num_kv_groups) + off_h // num_kv_groups + ) * tl.cdiv(kv_len, BLOCK_N) + k_bid_offset = ( + off_z * (H // num_kv_groups) + off_h // num_kv_groups + ) * stride_kbidq + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, HEAD_DIM) + Q_ptrs = ( + Q + + (off_z * stride_qz + off_h * stride_qh) + + offs_m[:, None] * stride_qn + + offs_k[None, :] + ) + Q_scale_ptr = Q_scale + q_scale_offset + start_m + K_ptrs = ( + K + + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + + offs_n[None, :] * stride_kn + + offs_k[:, None] + ) + K_scale_ptr = K_scale + k_scale_offset + K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk + V_ptrs = ( + V + + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + + offs_n[:, None] * stride_vn + + offs_k[None, :] + ) + O_block_ptr = ( + Out + + (off_z * stride_oz + off_h * stride_oh) + + offs_m[:, None] * stride_on + + offs_k[None, :] + ) + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len) + q_scale = tl.load(Q_scale_ptr) + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, q_scale, kv_len, + K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, + stride_kn, stride_vn, + start_m, + BLOCK_M, HEAD_DIM, BLOCK_N, + 4 - STAGE, offs_m, offs_n, + ) + if STAGE != 1: + acc, l_i, _ = _attn_fwd_inner( + acc, l_i, m_i, q, q_scale, kv_len, + K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, + stride_kn, stride_vn, + start_m, + BLOCK_M, HEAD_DIM, BLOCK_N, + 2, offs_m, offs_n, + ) + acc = acc / l_i[:, None] + tl.store( + O_block_ptr, + acc.to(Out.type.element_ty), + mask=(offs_m[:, None] < qo_len), + ) + + +def forward( + q, k, k_block_id, v, q_scale, k_scale, + is_causal=False, tensor_layout="HND", output_dtype=torch.float16, +): + BLOCK_M = 128 + BLOCK_N = 64 + stage = 3 if is_causal else 1 + o = torch.empty(q.shape, dtype=output_dtype, device=q.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2) + stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1) + stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1) + else: + raise ValueError(f"tensor_layout {tensor_layout} not supported") + + if is_causal: + assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention" + + HEAD_DIM_K = head_dim + num_kv_groups = h_qo // h_kv + + grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b) + _attn_fwd[grid]( + q, k, k_block_id, v, q_scale, k_scale, o, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_v, stride_h_v, stride_seq_v, + stride_bz_o, stride_h_o, stride_seq_o, + k_block_id.stride(1), k_block_id.stride(2), + qo_len, kv_len, + h_qo, num_kv_groups, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, + STAGE=stage, + num_warps=4 if head_dim == 64 else 8, + num_stages=4, + ) + return o diff --git a/flashvsr_arch/models/wan_video_dit.py b/flashvsr_arch/models/wan_video_dit.py index 5196bdf..0241082 100644 --- a/flashvsr_arch/models/wan_video_dit.py +++ b/flashvsr_arch/models/wan_video_dit.py @@ -31,15 +31,24 @@ except Exception: SAGE_ATTN_AVAILABLE = False try: - from sageattn.core import sparse_sageattn + from .sparse_sage.core import sparse_sageattn assert callable(sparse_sageattn) SPARSE_SAGE_AVAILABLE = True except Exception: - SPARSE_SAGE_AVAILABLE = False - sparse_sageattn = None + try: + from sageattn.core import sparse_sageattn + assert callable(sparse_sageattn) + SPARSE_SAGE_AVAILABLE = True + except Exception: + SPARSE_SAGE_AVAILABLE = False + sparse_sageattn = None from PIL import Image import numpy as np +print(f"[FlashVSR] Attention backends: sparse_sage={SPARSE_SAGE_AVAILABLE}, " + f"flash_attn_3={FLASH_ATTN_3_AVAILABLE}, flash_attn_2={FLASH_ATTN_2_AVAILABLE}, " + f"sage_attn={SAGE_ATTN_AVAILABLE}") + # ---------------------------- # Local / window masks