fix: inline prune helpers when removed from both transformers locations
find_pruneable_heads_and_indices and prune_linear_layer were removed from both pytorch_utils and modeling_utils in some transformers builds. Provide minimal inline implementations as final fallback — prune_heads() is never called at inference time so correctness is only needed for completeness. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -31,7 +31,32 @@ from transformers.modeling_utils import PreTrainedModel
|
||||
try:
|
||||
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
except ImportError:
|
||||
from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
try:
|
||||
from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
except ImportError:
|
||||
# Removed in newer transformers; prune_heads() is never called at inference time.
|
||||
import torch
|
||||
def find_pruneable_heads_and_indices(heads, n_heads, head_size, already_pruned_heads):
|
||||
mask = torch.ones(n_heads, head_size)
|
||||
heads = set(heads) - already_pruned_heads
|
||||
for head in heads:
|
||||
head -= sum(1 if h < head else 0 for h in already_pruned_heads)
|
||||
mask[head] = 0
|
||||
mask = mask.view(-1).contiguous().eq(1)
|
||||
index = torch.arange(len(mask))[mask].long()
|
||||
return heads, index
|
||||
|
||||
def prune_linear_layer(layer, index, dim=0):
|
||||
import torch.nn as nn
|
||||
index = index.to(layer.weight.device)
|
||||
W = layer.weight.index_select(dim, index).clone().detach()
|
||||
new = nn.Linear(W.shape[1], W.shape[0], bias=layer.bias is not None,
|
||||
device=layer.weight.device, dtype=layer.weight.dtype)
|
||||
new.weight = nn.Parameter(W)
|
||||
if layer.bias is not None:
|
||||
b = layer.bias.index_select(0, index).clone().detach() if dim == 0 else layer.bias.clone().detach()
|
||||
new.bias = nn.Parameter(b)
|
||||
return new
|
||||
from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig
|
||||
from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||
|
||||
|
||||
Reference in New Issue
Block a user