diff --git a/selva_core/ext/synchformer/hf_src/modeling_ast.py b/selva_core/ext/synchformer/hf_src/modeling_ast.py index 3355bb7..7ee6385 100644 --- a/selva_core/ext/synchformer/hf_src/modeling_ast.py +++ b/selva_core/ext/synchformer/hf_src/modeling_ast.py @@ -31,7 +31,32 @@ from transformers.modeling_utils import PreTrainedModel try: from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer except ImportError: - from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer + try: + from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer + except ImportError: + # Removed in newer transformers; prune_heads() is never called at inference time. + import torch + def find_pruneable_heads_and_indices(heads, n_heads, head_size, already_pruned_heads): + mask = torch.ones(n_heads, head_size) + heads = set(heads) - already_pruned_heads + for head in heads: + head -= sum(1 if h < head else 0 for h in already_pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + return heads, index + + def prune_linear_layer(layer, index, dim=0): + import torch.nn as nn + index = index.to(layer.weight.device) + W = layer.weight.index_select(dim, index).clone().detach() + new = nn.Linear(W.shape[1], W.shape[0], bias=layer.bias is not None, + device=layer.weight.device, dtype=layer.weight.dtype) + new.weight = nn.Parameter(W) + if layer.bias is not None: + b = layer.bias.index_select(0, index).clone().detach() if dim == 0 else layer.bias.clone().detach() + new.bias = nn.Parameter(b) + return new from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging