fix: inline prune helpers when removed from both transformers locations
find_pruneable_heads_and_indices and prune_linear_layer were removed from both pytorch_utils and modeling_utils in some transformers builds. Provide minimal inline implementations as final fallback — prune_heads() is never called at inference time so correctness is only needed for completeness. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -31,7 +31,32 @@ from transformers.modeling_utils import PreTrainedModel
|
|||||||
try:
|
try:
|
||||||
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
try:
|
||||||
from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
|
except ImportError:
|
||||||
|
# Removed in newer transformers; prune_heads() is never called at inference time.
|
||||||
|
import torch
|
||||||
|
def find_pruneable_heads_and_indices(heads, n_heads, head_size, already_pruned_heads):
|
||||||
|
mask = torch.ones(n_heads, head_size)
|
||||||
|
heads = set(heads) - already_pruned_heads
|
||||||
|
for head in heads:
|
||||||
|
head -= sum(1 if h < head else 0 for h in already_pruned_heads)
|
||||||
|
mask[head] = 0
|
||||||
|
mask = mask.view(-1).contiguous().eq(1)
|
||||||
|
index = torch.arange(len(mask))[mask].long()
|
||||||
|
return heads, index
|
||||||
|
|
||||||
|
def prune_linear_layer(layer, index, dim=0):
|
||||||
|
import torch.nn as nn
|
||||||
|
index = index.to(layer.weight.device)
|
||||||
|
W = layer.weight.index_select(dim, index).clone().detach()
|
||||||
|
new = nn.Linear(W.shape[1], W.shape[0], bias=layer.bias is not None,
|
||||||
|
device=layer.weight.device, dtype=layer.weight.dtype)
|
||||||
|
new.weight = nn.Parameter(W)
|
||||||
|
if layer.bias is not None:
|
||||||
|
b = layer.bias.index_select(0, index).clone().detach() if dim == 0 else layer.bias.clone().detach()
|
||||||
|
new.bias = nn.Parameter(b)
|
||||||
|
return new
|
||||||
from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig
|
from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig
|
||||||
from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user