Fix crash when flash_attn is installed but broken
Verify that attention backend functions are actually callable before marking them available, so a broken install falls back to PyTorch SDPA instead of calling None. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -11,26 +11,30 @@ from .utils import hash_state_dict_keys
|
|||||||
|
|
||||||
# Detect FlashAttention 3. Importing alone is not enough: a broken install can
# expose the module while `flash_attn_func` is missing or None, so verify the
# entry point is callable before advertising the backend. Any failure here —
# including ABI/CUDA errors raised at import time — leaves the flag False so
# callers fall back to another attention implementation.
try:
    import flash_attn_interface

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this broken-install guard.
    if not callable(getattr(flash_attn_interface, "flash_attn_func", None)):
        raise ImportError("flash_attn_interface.flash_attn_func is not callable")
    FLASH_ATTN_3_AVAILABLE = True
except Exception:
    FLASH_ATTN_3_AVAILABLE = False
|
|
||||||
# Detect FlashAttention 2. Verify the `flash_attn_func` entry point is actually
# callable, not merely that the package imports: a broken wheel (e.g. CUDA/ABI
# mismatch) can import but expose a missing/None symbol. Any failure leaves the
# flag False so callers fall back to another attention implementation.
try:
    import flash_attn

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this broken-install guard.
    if not callable(getattr(flash_attn, "flash_attn_func", None)):
        raise ImportError("flash_attn.flash_attn_func is not callable")
    FLASH_ATTN_2_AVAILABLE = True
except Exception:
    FLASH_ATTN_2_AVAILABLE = False
|
|
||||||
# Detect SageAttention. Confirm the imported `sageattn` symbol is callable
# before advertising the backend; any import-time or check failure leaves the
# flag False so callers fall back to another attention implementation.
try:
    from sageattention import sageattn

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this broken-install guard.
    if not callable(sageattn):
        raise ImportError("sageattention.sageattn is not callable")
    SAGE_ATTN_AVAILABLE = True
except Exception:
    SAGE_ATTN_AVAILABLE = False
|
|
||||||
# Detect sparse SageAttention. Confirm the imported `sparse_sageattn` symbol is
# callable before advertising the backend; any import-time or check failure
# leaves the flag False and the name bound to None.
try:
    from sageattn.core import sparse_sageattn

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this broken-install guard.
    if not callable(sparse_sageattn):
        raise ImportError("sageattn.core.sparse_sageattn is not callable")
    SPARSE_SAGE_AVAILABLE = True
except Exception:
    SPARSE_SAGE_AVAILABLE = False
    # Keep the name bound so call sites can test `sparse_sageattn is None`
    # instead of risking a NameError when the import failed.
    sparse_sageattn = None
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|||||||
Reference in New Issue
Block a user