diff --git a/flashvsr_arch/models/wan_video_dit.py b/flashvsr_arch/models/wan_video_dit.py
index 6bdb9d0..3b71b47 100644
--- a/flashvsr_arch/models/wan_video_dit.py
+++ b/flashvsr_arch/models/wan_video_dit.py
@@ -218,27 +218,37 @@ def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
 # ----------------------------
 # Attention kernels
 # ----------------------------
+def _sdpa_fallback(q, k, v, num_heads):
+    """PyTorch scaled dot-product attention (always available)."""
+    q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+    k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+    v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+    x = F.scaled_dot_product_attention(q, k, v)
+    return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+
+
 def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
+    global SPARSE_SAGE_AVAILABLE, SAGE_ATTN_AVAILABLE, FLASH_ATTN_2_AVAILABLE, FLASH_ATTN_3_AVAILABLE
+
     if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
-        seqlen = q.shape[1]
-        seqlen_kv = k.shape[1]
-        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
-        base_blockmask = attention_mask
-        x = sparse_sageattn(
-            q, k, v,
-            mask_id=base_blockmask.to(torch.int8),
-            is_causal=False,
-            tensor_layout="HND"
-        )
-        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        try:
+            # Use local names so the original (b, s, n*d) tensors stay intact for the fallback path.
+            q_ = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+            k_ = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+            v_ = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+            base_blockmask = attention_mask
+            x = sparse_sageattn(
+                q_, k_, v_,
+                mask_id=base_blockmask.to(torch.int8),
+                is_causal=False,
+                tensor_layout="HND"
+            )
+            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        except Exception:
+            SPARSE_SAGE_AVAILABLE = False
+            print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
+            x = _sdpa_fallback(q, k, v, num_heads)
     elif compatibility_mode:
-        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
-        x = F.scaled_dot_product_attention(q, k, v)
-        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        x = _sdpa_fallback(q, k, v, num_heads)
     elif FLASH_ATTN_3_AVAILABLE:
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
@@ -254,17 +264,18 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
         x = flash_attn.flash_attn_func(q, k, v)
         x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
     elif SAGE_ATTN_AVAILABLE:
-        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
-        x = sageattn(q, k, v)
-        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        try:
+            # Same pattern: keep q/k/v untouched so _sdpa_fallback sees the original layout.
+            q_ = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+            k_ = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+            v_ = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+            x = sageattn(q_, k_, v_)
+            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        except Exception:
+            SAGE_ATTN_AVAILABLE = False
+            print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
+            x = _sdpa_fallback(q, k, v, num_heads)
     else:
-        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
-        x = F.scaled_dot_product_attention(q, k, v)
-        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+        x = _sdpa_fallback(q, k, v, num_heads)
     return x