diff --git a/selva_core/ext/synchformer/hf_src/modeling_ast.py b/selva_core/ext/synchformer/hf_src/modeling_ast.py index 09ecf22..3355bb7 100644 --- a/selva_core/ext/synchformer/hf_src/modeling_ast.py +++ b/selva_core/ext/synchformer/hf_src/modeling_ast.py @@ -28,7 +28,10 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +try: + from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +except ImportError: + from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging