fix: inline prune helpers when removed from both transformers locations

find_pruneable_heads_and_indices and prune_linear_layer were removed from
both pytorch_utils and modeling_utils in some transformers builds. Provide
minimal inline implementations as final fallback — prune_heads() is never
called at inference time so correctness is only needed for completeness.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-04 16:30:58 +02:00
parent ab8e1e5b7b
commit 4da4858e4a
@@ -31,7 +31,32 @@ from transformers.modeling_utils import PreTrainedModel
try: try:
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
except ImportError: except ImportError:
from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer try:
from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer
except ImportError:
# Removed in newer transformers; prune_heads() is never called at inference time.
import torch
def find_pruneable_heads_and_indices(heads, n_heads, head_size, already_pruned_heads):
mask = torch.ones(n_heads, head_size)
heads = set(heads) - already_pruned_heads
for head in heads:
head -= sum(1 if h < head else 0 for h in already_pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
return heads, index
def prune_linear_layer(layer, index, dim=0):
import torch.nn as nn
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
new = nn.Linear(W.shape[1], W.shape[0], bias=layer.bias is not None,
device=layer.weight.device, dtype=layer.weight.dtype)
new.weight = nn.Parameter(W)
if layer.bias is not None:
b = layer.bias.index_select(0, index).clone().detach() if dim == 0 else layer.bias.clone().detach()
new.bias = nn.Parameter(b)
return new
from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig
from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging