ComfyUI-Tween/flashvsr_arch/models/sparse_sage/core.py
Ethanfel dd61ae8d1f Bundle sparse_sage Triton kernel for block-sparse attention
Without sparse attention, the model falls back to full (dense) attention,
which attends to distant, irrelevant information and causes ghosting
artifacts. The FlashVSR paper explicitly requires block-sparse attention.

Vendored from SageAttention team (Apache 2.0), pure Triton (no CUDA C++).
Import chain: local sparse_sage → external sageattn.core → SDPA fallback.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 19:22:40 +01:00
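The import chain described above can be sketched as a try/except cascade. This is a hypothetical wrapper, not code from the commit: the wrapper name `attention_with_fallback` and the exact external entry point (`sageattn` in `sageattn.core`) are assumptions inferred from the commit message.

```python
import torch
import torch.nn.functional as F

def attention_with_fallback(q, k, v, is_causal=False):
    """Hypothetical sketch of the import chain: vendored block-sparse
    kernel first, then an external SageAttention install, else dense SDPA."""
    try:
        # 1) Vendored Triton kernel (needs Triton and a CUDA GPU at runtime).
        from flashvsr_arch.models.sparse_sage.core import sparse_sageattn
        return sparse_sageattn(q, k, v, is_causal=is_causal)
    except (ImportError, RuntimeError):
        pass
    try:
        # 2) Externally installed SageAttention (entry point assumed from
        #    the commit message's "sageattn.core").
        from sageattn.core import sageattn
        return sageattn(q, k, v, is_causal=is_causal)
    except (ImportError, RuntimeError):
        pass
    # 3) Dense fallback: PyTorch's built-in scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```

On a machine without Triton or SageAttention installed, the cascade lands on the SDPA fallback, which is exactly the dense-attention case the commit message says produces ghosting.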


"""
Sparse SageAttention — block-sparse INT8 attention via Triton.
https://github.com/jt-zhang/Sparse_SageAttention_API
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0
"""
import torch

from .quant_per_block import per_block_int8
from .sparse_int8_attn import forward as sparse_sageattn_fwd


def sparse_sageattn(q, k, v, mask_id=None, is_causal=False, tensor_layout="HND"):
    # No block mask supplied: default to an all-ones map (every block
    # attended, i.e. dense), one int8 flag per 128x64 tile.
    if mask_id is None:
        mask_id = torch.ones(
            (q.shape[0], q.shape[1],
             (q.shape[2] + 128 - 1) // 128,
             (q.shape[3] + 64 - 1) // 64),
            dtype=torch.int8, device=q.device,
        )
    output_dtype = q.dtype
    # The Triton kernel consumes V in fp16; cast bf16/fp32 inputs down here
    # and let the kernel restore the original dtype via output_dtype.
    if output_dtype in (torch.bfloat16, torch.float32):
        v = v.to(torch.float16)
    # Per-head key mean along the sequence axis, passed to the quantizer
    # for smoothing before INT8 quantization.
    seq_dim = 1 if tensor_layout == "NHD" else 2
    km = k.mean(dim=seq_dim, keepdim=True)
    q_int8, q_scale, k_int8, k_scale = per_block_int8(
        q, k, km=km, tensor_layout=tensor_layout,
    )
    o = sparse_sageattn_fwd(
        q_int8, k_int8, mask_id, v, q_scale, k_scale,
        is_causal=is_causal, tensor_layout=tensor_layout,
        output_dtype=output_dtype,
    )
    return o
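`per_block_int8` lives in the sibling `quant_per_block` module and is not shown here. As a rough pure-PyTorch illustration of the underlying technique only (per-block absmax INT8 quantization), a sketch might look like the following; the vendored Triton kernel's actual block layout, scale storage, and K-mean smoothing may differ.

```python
import torch
import torch.nn.functional as F

def per_block_int8_sketch(x, block_size=128):
    """Hypothetical per-block INT8 quantizer: split the sequence axis
    into fixed-size blocks and scale each block by its absmax so values
    fit the int8 range. Illustration only, not the vendored kernel."""
    B, H, S, D = x.shape
    n_blocks = (S + block_size - 1) // block_size
    pad = n_blocks * block_size - S
    if pad:
        # Zero-pad the sequence axis up to a whole number of blocks.
        x = F.pad(x, (0, 0, 0, pad))
    xb = x.view(B, H, n_blocks, block_size, D)
    # One scale per (batch, head, block): absmax mapped onto [-127, 127].
    scale = xb.abs().amax(dim=(-2, -1), keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = torch.round(xb / scale).to(torch.int8)
    return x_int8, scale.squeeze(-1).squeeze(-1)
```

Multiplying `x_int8` by its block's scale recovers the input to within half a quantization step, which is the property the INT8 attention kernel relies on when it rescales the QK^T products with `q_scale` and `k_scale`.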