Fix sparse attention mask tiling for temporal windows
The local_attn_mask was not being tiled across temporal dimensions, causing assertion errors in streaming mode and wrong masks otherwise. Match naxci1 reference: 4D tile/rearrange for Q/K temporal windows, chunk-based score computation, and topk<=0 guard. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -167,50 +167,55 @@ def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
|
||||
q_w, k_w, topk=10, local_attn_mask=None):
|
||||
assert batch_size == 1, "Only batch_size=1 supported for now"
|
||||
assert local_attn_mask is not None, "local_attn_mask must be provided"
|
||||
|
||||
avgpool_q = torch.mean(q_w, dim=1)
|
||||
|
||||
avgpool_q = torch.mean(q_w, dim=1)
|
||||
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
|
||||
q_heads = avgpool_q.permute(1, 0, 2)
|
||||
|
||||
D = avgpool_q.shape[-1]
|
||||
|
||||
k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
|
||||
avgpool_k_split = torch.mean(k_w_split, dim=2)
|
||||
avgpool_k_refined = rearrange(avgpool_k_split, 's two d -> (s two) d', two=2)
|
||||
avgpool_k_refined = rearrange(avgpool_k_refined, 's (h d) -> s h d', h=nheads)
|
||||
k_heads = avgpool_k_refined.permute(1, 0, 2)
|
||||
|
||||
D = avgpool_q.shape[-1]
|
||||
scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
|
||||
|
||||
k_heads_doubled = avgpool_k_refined.permute(1, 0, 2)
|
||||
|
||||
k_heads_1, k_heads_2 = torch.chunk(k_heads_doubled, 2, dim=1)
|
||||
scores_1 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_1) / math.sqrt(D)
|
||||
scores_2 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_2) / math.sqrt(D)
|
||||
scores = torch.cat([scores_1, scores_2], dim=-1)
|
||||
|
||||
repeat_head = scores.shape[0]
|
||||
num_q_blocks_local = local_attn_mask.shape[0]
|
||||
num_k_blocks_local = local_attn_mask.shape[1]
|
||||
|
||||
local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
|
||||
|
||||
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
|
||||
repeat_num = scores.shape[2] // local_attn_mask.shape[1]
|
||||
|
||||
repeat_num = (scores.shape[2] // 2) // local_attn_mask.shape[1]
|
||||
|
||||
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
|
||||
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
|
||||
local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
|
||||
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
|
||||
|
||||
|
||||
local_attn_mask = local_attn_mask.to(torch.float32)
|
||||
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
|
||||
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
|
||||
|
||||
assert scores.shape == local_attn_mask.shape
|
||||
|
||||
|
||||
assert scores.shape == local_attn_mask.shape, \
|
||||
f"Scores shape {scores.shape} != Mask shape {local_attn_mask.shape}"
|
||||
|
||||
scores = scores + local_attn_mask
|
||||
attn_map = torch.softmax(scores, dim=-1)
|
||||
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen) # it=seqlen可能需要调整,取决于seqlen的含义
|
||||
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
|
||||
loop_num, s1, s2 = attn_map.shape
|
||||
flat = attn_map.reshape(loop_num, -1)
|
||||
apply_topk = min(flat.shape[1]-1, topk)
|
||||
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
|
||||
thresholds = thresholds.unsqueeze(1)
|
||||
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
|
||||
|
||||
if apply_topk <= 0:
|
||||
mask_new = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2)
|
||||
else:
|
||||
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
|
||||
thresholds = thresholds.unsqueeze(1)
|
||||
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
|
||||
|
||||
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
|
||||
|
||||
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user