diff --git a/flashvsr_arch/models/wan_video_dit.py b/flashvsr_arch/models/wan_video_dit.py index cc06947..6bdb9d0 100644 --- a/flashvsr_arch/models/wan_video_dit.py +++ b/flashvsr_arch/models/wan_video_dit.py @@ -11,26 +11,30 @@ from .utils import hash_state_dict_keys try: import flash_attn_interface + assert callable(getattr(flash_attn_interface, "flash_attn_func", None)) FLASH_ATTN_3_AVAILABLE = True -except ModuleNotFoundError: +except Exception: FLASH_ATTN_3_AVAILABLE = False try: import flash_attn + assert callable(getattr(flash_attn, "flash_attn_func", None)) FLASH_ATTN_2_AVAILABLE = True -except ModuleNotFoundError: +except Exception: FLASH_ATTN_2_AVAILABLE = False try: from sageattention import sageattn + assert callable(sageattn) SAGE_ATTN_AVAILABLE = True -except ModuleNotFoundError: +except Exception: SAGE_ATTN_AVAILABLE = False try: from sageattn.core import sparse_sageattn + assert callable(sparse_sageattn) SPARSE_SAGE_AVAILABLE = True -except ModuleNotFoundError: +except Exception: SPARSE_SAGE_AVAILABLE = False sparse_sageattn = None from PIL import Image