Fix crash when flash_attn is installed but broken

Verify attention backend functions are actually callable before marking
them available. Falls back to PyTorch SDPA instead of calling None.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 15:51:30 +01:00
parent 8317a0603e
commit f40504cbcf

View File

@@ -11,26 +11,30 @@ from .utils import hash_state_dict_keys
try: try:
import flash_attn_interface import flash_attn_interface
assert callable(getattr(flash_attn_interface, "flash_attn_func", None))
FLASH_ATTN_3_AVAILABLE = True FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError: except Exception:
FLASH_ATTN_3_AVAILABLE = False FLASH_ATTN_3_AVAILABLE = False
try: try:
import flash_attn import flash_attn
assert callable(getattr(flash_attn, "flash_attn_func", None))
FLASH_ATTN_2_AVAILABLE = True FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError: except Exception:
FLASH_ATTN_2_AVAILABLE = False FLASH_ATTN_2_AVAILABLE = False
try: try:
from sageattention import sageattn from sageattention import sageattn
assert callable(sageattn)
SAGE_ATTN_AVAILABLE = True SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError: except Exception:
SAGE_ATTN_AVAILABLE = False SAGE_ATTN_AVAILABLE = False
try: try:
from sageattn.core import sparse_sageattn from sageattn.core import sparse_sageattn
assert callable(sparse_sageattn)
SPARSE_SAGE_AVAILABLE = True SPARSE_SAGE_AVAILABLE = True
except ModuleNotFoundError: except Exception:
SPARSE_SAGE_AVAILABLE = False SPARSE_SAGE_AVAILABLE = False
sparse_sageattn = None sparse_sageattn = None
from PIL import Image from PIL import Image