chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Vladimir Iashin
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1 @@
|
||||
from selva_core.ext.synchformer.synchformer import Synchformer
|
||||
@@ -0,0 +1,279 @@
|
||||
import logging
|
||||
import torch
|
||||
from torch import nn
|
||||
# importing modified version of AST
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
||||
|
||||
from selva_core.ext.synchformer.hf_src.modeling_ast import ASTForAudioClassification, ASTConfig
|
||||
from selva_core.ext.synchformer.motionformer import (AveragePooling, BaseEncoderLayer,
|
||||
TemporalTransformerEncoderLayer)
|
||||
from selva_core.ext.synchformer.utils import check_if_file_exists_else_download
|
||||
|
||||
|
||||
class AST(torch.nn.Module):
    '''Wrapper around the (token-masking) Audio Spectrogram Transformer.

    With `extract_features=True` the module returns features that can optionally be aggregated along
    frequency, time and segments; otherwise it returns the output of the classification head.
    '''

    def __init__(self,
                 extract_features: bool = False,
                 ckpt_path: str = None,
                 feat_type: str = None,
                 max_spec_t: int = None,
                 factorize_freq_time: bool = None,
                 agg_freq_module: str = None,
                 agg_time_module: str = None,
                 add_global_repr: bool = True,
                 agg_segments_module: str = None,
                 max_segments: int = None,
                 ) -> None:
        '''
        extract_features: if True, then the model will return the features instead of head's output
        ckpt_path: is not a path to a ckpt file, but a name of a model from the HuggingFace model hub.
                   (a value ending with `.pt` is additionally treated as an AVCLIP checkpoint from which
                   the `a_encoder.*` weights are loaded)
        feat_type: if extract_features is True, this parameter specifies the type of features to return
        max_spec_t: if specified, then the model (pos emb) will be patched to support this length of spec
        factorize_freq_time: if True, then the model will use a factorized freq/time aggregation
        agg_freq_module: if specified, then the model will use this module for freq aggregation
        agg_time_module: if specified, then the model will use this module for time aggregation
        add_global_repr: if True, adds a global representation to the features (aggregation on segments)
        agg_segments_module: if specified, then the model will use this module for segments aggregation
        max_segments: if specified, the initialization of PE in the global agg module will use this value.
                      This should correspond to the max number of segments per video (if None, 16 is used)

        Raises ValueError if `factorize_freq_time` is truthy but `agg_time_module` is None.
        '''
        super().__init__()
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.max_spec_t = max_spec_t
        self.max_segments = max_segments

        # depending on whether the feat extractor was pre-trained contrastively or not, we need to
        # load the state dict differently.

        # if ckpt is specified, then load the model from the HuggingFace model hub, otherwise init a new model
        if ckpt_path == 'MIT/ast-finetuned-audioset-10-10-0.4593':
            revision = 'c1c0c66'  # fixing the revision for compatibility (V4.27.4)
            self.config = ASTConfig.from_pretrained(ckpt_path, revision=revision)
            full_model = ASTForAudioClassification.from_pretrained(ckpt_path, revision=revision)
            logging.info(f'Loaded AST from {ckpt_path}')
        else:
            self.config = ASTConfig()
            self.config.num_labels = 527  # 2 by default, audioset has 527 labels
            full_model = ASTForAudioClassification(self.config)
            logging.info('Initialized AST from scratch with the AST AudioSet config')

        was_pt_on_avclip = ckpt_path is not None and ckpt_path.endswith('.pt')

        # feature extractor
        self.ast = full_model.audio_spectrogram_transformer

        if self.extract_features:
            # assign `feat_type` (use default if not specified)
            self.feat_type = 'last_hidden_state' if feat_type is None else feat_type
            # define adapters if needed
            self.factorize_freq_time = factorize_freq_time
            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
            transf_enc_layer_kwargs = dict(
                d_model=self.config.hidden_size, nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.intermediate_size, activation=nn.GELU(), batch_first=True,
                dropout=self.config.attention_probs_dropout_prob, layer_norm_eps=1e-6, norm_first=True,
            )
            if factorize_freq_time:
                self.feat_type = 'last_hidden_state'  # this feat_type supports factorization
                # frequency aggregation
                if agg_freq_module == 'TransformerEncoderLayer':
                    self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_freq_module == 'AveragePooling':
                    self.freq_attn_agg = AveragePooling(avg_pattern='BS D f t -> BS D t',
                                                        then_permute_pattern='BS D t -> BS t D')
                # time aggregation
                if agg_time_module is None:
                    # `'Identity' in None` below would raise an opaque TypeError; fail with a clear message
                    raise ValueError('`agg_time_module` must be specified if `factorize_freq_time` is True')
                elif agg_time_module == 'TransformerEncoderLayer':
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == 'AveragePooling':
                    self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
                elif 'Identity' in agg_time_module:
                    self.temp_attn_agg = nn.Identity()
            # define a global aggregation layer (aggregate over segments)
            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == 'TransformerEncoderLayer':
                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
                    # we need to add pos emb (PE) because previously we added the same PE for each segment
                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True, pos_emb_drop=self.config.hidden_dropout_prob,
                        pos_max_len=pos_max_len, **transf_enc_layer_kwargs
                    )
                elif agg_segments_module == 'AveragePooling':
                    self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')
        else:
            self.classifier = full_model.classifier

        # AST.device fails with AttributeError. This is a workaround
        self.device = full_model.device

        # pre-trained on 12*101+2=1214 tokens, but we have less (e.g. 12*6+2=74)
        self.patch_position_emb()

        if was_pt_on_avclip:
            # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
            # and keep only the state_dict of the feat extractor
            check_if_file_exists_else_download(self.ckpt_path)
            # NOTE(review): `torch.load` unpickles the file; only load checkpoints from a trusted source.
            ckpt = torch.load(ckpt_path, map_location='cpu')
            ckpt_weights = dict()
            for k, v in ckpt['state_dict'].items():
                if k.startswith(('module.a_encoder.', 'a_encoder.')):
                    k = k.replace('module.', '').replace('a_encoder.', '')
                    ckpt_weights[k] = v
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if _load_status.missing_keys or _load_status.unexpected_keys:
                logging.warning(f'Loading exact afeat_extractor ckpt from {self.ckpt_path} failed. \n' \
                                f'Missing keys ({len(_load_status.missing_keys)}): ' \
                                f'{_load_status.missing_keys}, \n' \
                                f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
                                f'{_load_status.unexpected_keys} \n' \
                                f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.')
            else:
                logging.info(f'Loading afeat_extractor ckpt from {self.ckpt_path} succeeded.')

        # print the number of parameters
        logging.info(f'AST: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}')

    def forward(self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None,
                **ast_kwargs) -> tuple:
        '''
        x: (B, S, T, F) where S is number of segments, F is number of (mel) frequency bins,
        ast_kwargs: additional arguments for the AST model
        cont_mask: (B, S, T, F) where 0s are the values to be masked out
        if `for_loop=True`, we use a for loop to extract features for each segment separately.
        if `for_loop=False`, we extract features for all segments at once.
            Using the for loop is slower but more memory efficient, while using all segments at once
            is faster but more memory inefficient.
            Using for loop allows to control the memory footprint by varying the number of videos in a
            batch (batch size) rather than the number of segments in a video.

        Returns a tuple `(x, global_x)`: x is (B, S, ...), global_x is (B, D) or None.
        '''
        B, S, T, F = x.shape

        if for_loop:
            assert cont_mask is None, 'cont_mask is not supported with for_loop=True'
            orig_shape_s = (B, 1, T, F)
            # NOTE: since x is (B, S, T, F), and forward_segments expects (BS, T, F).
            # (B, S, T, F)[:, s] is (B, T, F) or (BS, T, F) if S=1.
            x = torch.cat(
                [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)],
                dim=1)
        else:
            orig_shape = (B, S, T, F)
            # `reshape` (as for the mask below) also accepts non-contiguous inputs, unlike `view`
            x = x.reshape(B * S, T, F)
            if cont_mask is not None:
                cont_mask = cont_mask.reshape(B * S, T, F)
            # AST expects a tensor of shape (B*S, T, F).
            x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
            # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
            x = x.view(B, S, *x.shape[1:])
        # x now is of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity`

        global_x = None
        # `add_global_repr` is only set when `extract_features` is True; the short-circuit `and`
        # avoids an AttributeError in the classification mode
        if self.extract_features and self.add_global_repr:
            assert len(x.shape) == 3, f'Local representation should be (B, S, D) {x.shape}'
            global_x = self.global_attn_agg(x)  # (B, D)

        return x, global_x  # x is (B, S, ...), global_x is (B, D) or None

    def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs):
        '''x is (BS, T, F), where S is the number of segments; cont_mask is (BS, T, F): 0s to be masked out'''
        # 'pooler_output': (B, D); or 'last_hidden_state: (B, T, D) where T is [CLS, DISTILL, <tokens>]
        # x_mask is (B, T) where 0s are the values to be masked out
        x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)

        if self.extract_features:
            x = self.get_features_by_type(x)
            if self.factorize_freq_time:
                x = self.restore_freq_temp_dims(x, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
                if cont_mask is not None:
                    # duplicating the mask for the latent dimension (D) to be compatible with the next func
                    x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
                    x_mask = self.restore_freq_temp_dims(x_mask, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
                    # again removing the latent
                    x_mask = x_mask[:, 0, :, :]
                else:
                    x_mask = None
                x = self.freq_attn_agg(x, x_mask)  # (BS, t, D)
                x = self.temp_attn_agg(x)  # (BS, D) or (BS, t, D) if self.temp_attn_agg is Identity
        else:
            x = x['pooler_output']
            x = self.classifier(x)
        return x

    def get_features_by_type(self, x: BaseModelOutputWithPooling) -> torch.Tensor:
        '''Selects the tensor that corresponds to `self.feat_type` from the AST output `x`.

        Raises ValueError if `self.feat_type` is not one of the supported names.
        '''
        if self.feat_type == 'pooler_output':
            return x['pooler_output']  # (B, D)
        elif self.feat_type == 'CLS':
            return x['last_hidden_state'][:, 0, :]  # (B, D)
        elif self.feat_type == 'last_hidden_state':
            return x['last_hidden_state']  # (B, 2+T, D)
        elif self.feat_type == 'last_hidden_state_no_AUX':
            return x['last_hidden_state'][:, 2:, :]  # (B, T, D) removing CLS and distill tokens
        else:
            raise ValueError(f'Unknown feature type: {self.feat_type}')

    def restore_freq_temp_dims(self, feats, orig_shape: tuple):
        '''
        feats are of shape (B*S, T, D)
            where T = 2 + f * t (if feat_type == 'last_hidden_state')
            where T = f * t (if feat_type == 'last_hidden_state_no_AUX')
        Our goal is to make them of shape (B*S, D, f, t) where f and t are dimensions after patching.
        From `self.ast.embeddings.patch_embeddings`, it follows that we could reshape feats:
            `feats.transpose(1, 2).view(B*S, D, f, t)`

        (Similar function is defined in for RGB features in `motionformer.py`)
        '''
        B, S, T, F = orig_shape
        D = self.config.hidden_size

        # num patches in each dimension
        f, t = self.ast.embeddings.get_shape(self.config)

        if self.feat_type == 'last_hidden_state':
            feats = feats[:, 2:, :]  # removing CLS and distill tokens

        feats = feats.permute(0, 2, 1)  # (B*S, D, T)
        feats = feats.view(B * S, D, f, t)  # (B*S, D, f, t)

        return feats

    def patch_position_emb(self):
        '''Shortens the positional embeddings to the number of tokens implied by `self.config`.

        If `self.max_spec_t` is set, it first overrides `self.config.max_length`.
        '''
        if self.max_spec_t is not None:
            self.config.max_length = self.max_spec_t
        f, t = self.ast.embeddings.get_shape(self.config)
        # +2 for CLS and distill tokens
        shortened = self.ast.embeddings.position_embeddings[:, :f*t+2].detach().clone()
        # move the data first and wrap it afterwards: `Parameter(...).to(other_device)` returns a plain
        # Tensor, which cannot be assigned to a registered parameter
        self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened.to(self.device))

    def to(self, device):
        '''AST.device fails with AttributeError. This is a workaround.

        Only a device-like argument is supported (it is passed through `torch.device`).
        '''
        self.device = torch.device(device)
        return super().to(device)
|
||||
|
||||
|
||||
class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
    ''' Aggregates the features along the frequency axis.
    The logic mirrors the spatio-temporal aggregation of the visual feature extractor, so
    the definition of `BaseEncoderLayer` in `motionformer.py` is the place to look for details. '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        ''' x: (B*S, D, f, t); if specified x_mask (B*S, f, t), 0s are the values to be masked out '''
        num_segments, dim, num_freq, num_time = x.shape
        flat_batch = num_segments * num_time

        # fold time into the batch dimension: (B*S, D, f, t) -> (B*S, t, f, D) -> (B*S*t, f, D);
        # `reshape` is used because the permuted tensor is not contiguous
        tokens = x.permute(0, 3, 2, 1).reshape(flat_batch, num_freq, dim)

        # the mask gets the same treatment: (B*S, f, t) -> (B*S, t, f) -> (B*S*t, f)
        tok_mask = None
        if x_mask is not None:
            tok_mask = x_mask.permute(0, 2, 1).reshape(flat_batch, num_freq)

        # BaseEncoderLayer.forward - it will add CLS token and output its representation: (B*S*t, D)
        pooled = super().forward(x=tokens, x_mask=tok_mask)

        # unfold time back: (B*S, t, D)
        return pooled.view(num_segments, num_time, dim)
|
||||
@@ -0,0 +1,84 @@
|
||||
TRAIN:
|
||||
ENABLE: True
|
||||
DATASET: Ssv2
|
||||
BATCH_SIZE: 32
|
||||
EVAL_PERIOD: 5
|
||||
CHECKPOINT_PERIOD: 5
|
||||
AUTO_RESUME: True
|
||||
CHECKPOINT_EPOCH_RESET: True
|
||||
CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth
|
||||
DATA:
|
||||
NUM_FRAMES: 16
|
||||
SAMPLING_RATE: 4
|
||||
TRAIN_JITTER_SCALES: [256, 320]
|
||||
TRAIN_CROP_SIZE: 224
|
||||
TEST_CROP_SIZE: 224
|
||||
INPUT_CHANNEL_NUM: [3]
|
||||
MEAN: [0.5, 0.5, 0.5]
|
||||
STD: [0.5, 0.5, 0.5]
|
||||
PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2
|
||||
PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames
|
||||
INV_UNIFORM_SAMPLE: True
|
||||
RANDOM_FLIP: False
|
||||
REVERSE_INPUT_CHANNEL: True
|
||||
USE_RAND_AUGMENT: True
|
||||
RE_PROB: 0.0
|
||||
USE_REPEATED_AUG: False
|
||||
USE_RANDOM_RESIZE_CROPS: False
|
||||
COLORJITTER: False
|
||||
GRAYSCALE: False
|
||||
GAUSSIAN: False
|
||||
SOLVER:
|
||||
BASE_LR: 1e-4
|
||||
LR_POLICY: steps_with_relative_lrs
|
||||
LRS: [1, 0.1, 0.01]
|
||||
STEPS: [0, 20, 30]
|
||||
MAX_EPOCH: 35
|
||||
MOMENTUM: 0.9
|
||||
WEIGHT_DECAY: 5e-2
|
||||
WARMUP_EPOCHS: 0.0
|
||||
OPTIMIZING_METHOD: adamw
|
||||
USE_MIXED_PRECISION: True
|
||||
SMOOTHING: 0.2
|
||||
SLOWFAST:
|
||||
ALPHA: 8
|
||||
VIT:
|
||||
PATCH_SIZE: 16
|
||||
PATCH_SIZE_TEMP: 2
|
||||
CHANNELS: 3
|
||||
EMBED_DIM: 768
|
||||
DEPTH: 12
|
||||
NUM_HEADS: 12
|
||||
MLP_RATIO: 4
|
||||
QKV_BIAS: True
|
||||
VIDEO_INPUT: True
|
||||
TEMPORAL_RESOLUTION: 8
|
||||
USE_MLP: True
|
||||
DROP: 0.0
|
||||
POS_DROPOUT: 0.0
|
||||
DROP_PATH: 0.2
|
||||
IM_PRETRAINED: True
|
||||
HEAD_DROPOUT: 0.0
|
||||
HEAD_ACT: tanh
|
||||
PRETRAINED_WEIGHTS: vit_1k
|
||||
ATTN_LAYER: divided
|
||||
MODEL:
|
||||
NUM_CLASSES: 174
|
||||
ARCH: slow
|
||||
MODEL_NAME: VisionTransformer
|
||||
LOSS_FUNC: cross_entropy
|
||||
TEST:
|
||||
ENABLE: True
|
||||
DATASET: Ssv2
|
||||
BATCH_SIZE: 64
|
||||
NUM_ENSEMBLE_VIEWS: 1
|
||||
NUM_SPATIAL_CROPS: 3
|
||||
DATA_LOADER:
|
||||
NUM_WORKERS: 4
|
||||
PIN_MEMORY: True
|
||||
NUM_GPUS: 8
|
||||
NUM_SHARDS: 4
|
||||
RNG_SEED: 0
|
||||
OUTPUT_DIR: .
|
||||
TENSORBOARD:
|
||||
ENABLE: True
|
||||
@@ -0,0 +1,662 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 MIT and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Modified by v-iashin to support token masking
|
||||
|
||||
""" PyTorch Audio Spectrogram Transformer (AST) model."""
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig
|
||||
from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||
|
||||
|
||||
# Module-level logger obtained from the `transformers` logging utilities.
logger = logging.get_logger(__name__)

# General docstring
# (name of the config class; presumably consumed by the docstring decorators imported above — confirm)
_CONFIG_FOR_DOC = "ASTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "MIT/ast-finetuned-audioset-10-10-0.4593"
_EXPECTED_OUTPUT_SHAPE = [1, 1214, 768]

# Audio classification docstring
_SEQ_CLASS_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
_SEQ_CLASS_EXPECTED_OUTPUT = "'Speech'"
_SEQ_CLASS_EXPECTED_LOSS = 0.17


# Checkpoint names listed for this model family.
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    # See all Audio Spectrogram Transformer models at https://huggingface.co/models?filter=ast
]
|
||||
|
||||
|
||||
class ASTEmbeddings(nn.Module):
    """
    Builds the input token sequence: CLS token, distillation token and patch embeddings, with
    position embeddings added and dropout applied.
    """

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()

        hidden = config.hidden_size
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden))
        self.distillation_token = nn.Parameter(torch.zeros(1, 1, hidden))
        self.patch_embeddings = ASTPatchEmbeddings(config)

        # one position per patch plus two for the CLS and distillation tokens
        freq_patches, time_patches = self.get_shape(config)
        self.position_embeddings = nn.Parameter(torch.zeros(1, freq_patches * time_patches + 2, hidden))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def get_shape(self, config):
        """Return the number of patches along (frequency, time) for the given config."""
        # see Karpathy's cs231n blog on how to calculate the output dimensions
        # https://cs231n.github.io/convolutional-networks/#conv
        patch = config.patch_size
        freq_patches = (config.num_mel_bins - patch) // config.frequency_stride + 1
        time_patches = (config.max_length - patch) // config.time_stride + 1
        return freq_patches, time_patches

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Embed `input_values` into patches and prepend the two special tokens."""
        patches = self.patch_embeddings(input_values)
        num_items = input_values.shape[0]

        # broadcast the two learned tokens over the batch
        cls = self.cls_token.expand(num_items, -1, -1)
        distill = self.distillation_token.expand(num_items, -1, -1)

        tokens = torch.cat((cls, distill, patches), dim=1)
        return self.dropout(tokens + self.position_embeddings)
|
||||
|
||||
|
||||
class ASTPatchEmbeddings(nn.Module):
    """
    Converts `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size,
    seq_length, hidden_size)` that a Transformer consumes.
    """

    def __init__(self, config):
        super().__init__()

        # square patches, with independent strides along frequency and time
        kernel = (config.patch_size, config.patch_size)
        strides = (config.frequency_stride, config.time_stride)
        self.projection = nn.Conv2d(1, config.hidden_size, kernel_size=kernel, stride=strides)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        # add a channel axis and swap the last two axes before the convolution
        spec = input_values.unsqueeze(1).transpose(2, 3)
        projected = self.projection(spec)
        # collapse the two patch axes into one sequence axis and put it before the hidden axis
        return projected.flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST
|
||||
class ASTSelfAttention(nn.Module):
    """Multi-head self-attention with an optional per-token mask (`tok_mask`)."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            # NOTE: the value is formatted directly; a trailing comma inside the braces would build a
            # 1-tuple and print e.g. "(768,)" instead of "768"
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into (heads, head_size) and move the heads axis to position 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, tok_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Apply self-attention to `hidden_states`.

        tok_mask: (BS, N) where 1s mark the tokens to keep; positions equal to 0 are not attended to.
        head_mask: if given, it is multiplied with the attention probabilities.
        output_attentions: if True, the attention probabilities are returned as a second element.

        Returns `(context_layer,)` or `(context_layer, attention_probs)`.
        """
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # apply masking if provided, tok_mask is (BS, N): 1s - keep; attention_scores is (BS, H, N, N)
        if tok_mask is not None:
            BS, N = tok_mask.shape
            # -inf scores become exactly 0 after the softmax
            attention_scores = attention_scores.masked_fill(tok_mask.view(BS, 1, 1, N) == 0, float('-inf'))

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # merge the heads back: (BS, H, N, head_size) -> (BS, N, H * head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
|
||||
class ASTSelfOutput(nn.Module):
    """
    Output projection of the attention block. No residual is added here (unlike in other models):
    the layernorm is applied before each block, so the residual connection lives in ASTLayer.
    """

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted for interface parity but is intentionally unused
        return self.dropout(self.dense(hidden_states))
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->AST
|
||||
class ASTAttention(nn.Module):
    """Self-attention followed by its output projection, with support for head pruning."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.attention = ASTSelfAttention(config)
        self.output = ASTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given attention heads from the query/key/value and output projections."""
        if not heads:
            return

        self_attn = self.attention
        heads, index = find_pruneable_heads_and_indices(
            heads, self_attn.num_attention_heads, self_attn.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self_attn.query = prune_linear_layer(self_attn.query, index)
        self_attn.key = prune_linear_layer(self_attn.key, index)
        self_attn.value = prune_linear_layer(self_attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        remaining_heads = self_attn.num_attention_heads - len(heads)
        self_attn.num_attention_heads = remaining_heads
        self_attn.all_head_size = self_attn.attention_head_size * remaining_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        tok_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Return `(attention_output,)` plus the attention probabilities if they were requested."""
        attn_results = self.attention(hidden_states, tok_mask, head_mask, output_attentions)
        projected = self.output(attn_results[0], hidden_states)
        # everything after the first element (the attentions, if any) is passed through untouched
        return (projected,) + attn_results[1:]
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
|
||||
class ASTIntermediate(nn.Module):
    """First half of the feed-forward block: a linear expansion followed by the activation."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # a string names an activation in ACT2FN; anything else is used as the activation itself
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->AST
|
||||
class ASTOutput(nn.Module):
    """Second half of the feed-forward block: projection back, dropout and the residual sum."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # residual connection with the tensor supplied by the caller
        return projected + input_tensor
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST
|
||||
class ASTLayer(nn.Module):
    """One pre-norm Transformer block (the counterpart of the Block class in timm)."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ASTAttention(config)
        self.intermediate = ASTIntermediate(config)
        self.output = ASTOutput(config)
        norm_eps = config.layer_norm_eps
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        tok_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Return `(layer_output,)` plus the attention weights if they were requested."""
        # in AST, layernorm is applied before self-attention
        normed = self.layernorm_before(hidden_states)
        attn_results = self.attention(
            normed,
            tok_mask,
            head_mask,
            output_attentions=output_attentions,
        )

        # first residual connection
        hidden_states = attn_results[0] + hidden_states

        # in AST, layernorm is also applied after self-attention
        expanded = self.intermediate(self.layernorm_after(hidden_states))

        # second residual connection is done inside `self.output`
        layer_output = self.output(expanded, hidden_states)

        # keep the self attentions (if any) after the layer output
        return (layer_output,) + attn_results[1:]
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->AST
|
||||
class ASTEncoder(nn.Module):
    """Sequential stack of `config.num_hidden_layers` `ASTLayer` blocks."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ASTLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        tok_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """Apply every layer in order.

        `head_mask`, when given, is indexed per layer (`head_mask[i]`). With `output_hidden_states` the
        input of every layer plus the final output are collected; with `output_attentions` the second
        element of every layer's result is collected. When `return_dict` is false, a plain tuple of the
        non-None results is returned instead of a `BaseModelOutput`.
        """
        hidden_trace = () if output_hidden_states else None
        attention_trace = () if output_attentions else None

        for idx, block in enumerate(self.layer):
            if output_hidden_states:
                hidden_trace = hidden_trace + (hidden_states,)

            block_head_mask = None if head_mask is None else head_mask[idx]

            if self.gradient_checkpointing and self.training:

                # Bind the current block eagerly (default argument) so that a later re-execution by the
                # checkpointing machinery still calls this block and not a later one from the loop.
                def run_block(*inputs, _block=block):
                    return _block(*inputs, output_attentions)

                block_outputs = torch.utils.checkpoint.checkpoint(
                    run_block,
                    hidden_states,
                    tok_mask,
                    block_head_mask,
                )
            else:
                block_outputs = block(hidden_states, tok_mask, block_head_mask, output_attentions)

            hidden_states = block_outputs[0]

            if output_attentions:
                attention_trace = attention_trace + (block_outputs[1],)

        if output_hidden_states:
            hidden_trace = hidden_trace + (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=hidden_trace,
                attentions=attention_trace,
            )
        # tuple form: drop the entries that were not requested
        candidates = (hidden_states, hidden_trace, attention_trace)
        return tuple(item for item in candidates if item is not None)
|
||||
|
||||
|
||||
class ASTPreTrainedModel(PreTrainedModel):
    """
    Abstract base for the AST models: provides weight initialization and, through `PreTrainedModel`,
    a simple interface for downloading and loading pretrained models.
    """

    config_class = ASTConfig
    base_model_prefix = "audio_spectrogram_transformer"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights of a single sub-module.

        LayerNorm is reset to weight 1 / bias 0. Linear and Conv2d weights are drawn from a truncated
        normal with std `config.initializer_range` and their bias (if any) is zeroed. Other module
        types are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
        # `trunc_normal_cpu` not implemented in `half` issues
        target_dtype = module.weight.dtype
        fp32_weight = nn.init.trunc_normal_(
            module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
        )
        module.weight.data = fp32_weight.to(target_dtype)
        if module.bias is not None:
            module.bias.data.zero_()

    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
    def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None:
        """Toggle gradient checkpointing; only `ASTEncoder` instances carry the flag."""
        if not isinstance(module, ASTEncoder):
            return
        module.gradient_checkpointing = value
|
||||
|
||||
|
||||
# Docstring fragment shared by the model classes in this file: it is passed to `add_start_docstrings`
# when decorating `ASTModel` and `ASTForAudioClassification`.
AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ASTConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Docstring fragment describing the `forward` arguments: it is passed to
# `add_start_docstrings_to_model_forward` on the `forward` methods in this file.
# NOTE(review): the `cont_mask` argument accepted by those `forward` methods is not described in this
# text, and `input_values` is worded as "Pixel values" -- confirm whether that is intended.
AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
            [`ASTFeatureExtractor.__call__`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
    "The bare AST Model transformer outputting raw hidden-states without any specific head on top.",
    AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
)
class ASTModel(ASTPreTrainedModel):
    # Pipeline: embeddings -> encoder -> final LayerNorm. Compared to a plain transformer forward, an
    # optional content mask (`cont_mask`) is converted into a per-token mask and given to the encoder.
    # (Documented with comments rather than docstrings because the decorators build `__doc__`.)

    def __init__(self, config: ASTConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = ASTEmbeddings(config)
        self.encoder = ASTEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ASTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_idx, head_ids in heads_to_prune.items():
            self.encoder.layer[layer_idx].attention.prune_heads(head_ids)

    def _token_mask_from_content_mask(self, input_values: torch.Tensor,
                                      cont_mask: torch.Tensor) -> torch.Tensor:
        # Transforms the mask that has spectrogram dims into the token mask obtained after patching.
        # Due to the overlap in patching, getting the token mask from the spectrogram mask is not
        # straightforward, because one 16x16 content patch is encoded in two tokens if stride is <16.
        # So the patching function (`self.embeddings`) is applied to a tensor with infinities at the
        # masked content positions: for infs it returns nans, and the nans mark the masked tokens.
        indicator = torch.ones_like(input_values)
        # replace content mask (0s) with infs
        indicator[~cont_mask] = torch.inf
        # apply patching; now nans are where the content mask was
        with torch.no_grad():
            patched = self.embeddings(indicator)  # BS, N, D
        # tokens that are NOT nan are the ones to keep
        keep = ~torch.isnan(patched)
        # since all values in the D-dimension (latent) will also be nans, we can just use the first el
        return keep[:, :, 0]  # (BS, 2+num_patches) -- 2 is from CLS and DISTIL tokens

    @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        cont_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # NOTE: the two return paths differ. With `return_dict` a pair
        # `(BaseModelOutputWithPooling, tok_mask)` is returned; without it, a flat tuple
        # `(sequence_output, pooled_output, *encoder_extras)` is returned and `tok_mask` is not included.
        config = self.config
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.use_return_dict

        if input_values is None:
            raise ValueError("You have to specify input_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(input_values)

        tok_mask = None
        if cont_mask is not None:
            tok_mask = self._token_mask_from_content_mask(input_values, cont_mask)

        encoder_outputs = self.encoder(
            embedding_output,
            tok_mask=tok_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.layernorm(encoder_outputs[0])

        # pooled representation: mean of the first two tokens of the sequence
        pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        model_output = BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        return model_output, tok_mask
|
||||
|
||||
|
||||
class ASTMLPHead(nn.Module):
    """Classification head: LayerNorm followed by a linear layer to `config.num_labels` outputs.

    When `config.num_labels` is not positive the linear layer is replaced by an identity, so the head
    only normalizes its input.
    """

    def __init__(self, config: ASTConfig):
        super().__init__()
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        if config.num_labels > 0:
            self.dense = nn.Linear(config.hidden_size, config.num_labels)
        else:
            self.dense = nn.Identity()

    def forward(self, hidden_state):
        """Return `dense(layernorm(hidden_state))`."""
        return self.dense(self.layernorm(hidden_state))
|
||||
|
||||
|
||||
@add_start_docstrings(
    """
    Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled
    output) e.g. for datasets like AudioSet, Speech Commands v2.
    """,
    AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
)
class ASTForAudioClassification(ASTPreTrainedModel):
    def __init__(self, config: ASTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.audio_spectrogram_transformer = ASTModel(config)

        # Classifier head
        self.classifier = ASTMLPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_SEQ_CLASS_CHECKPOINT,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        cont_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.audio_spectrogram_transformer(
            input_values,
            cont_mask=cont_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # With `return_dict`, the backbone returns the pair `(model_output, tok_mask)` rather than the
        # model output alone. Unwrap it first: otherwise `outputs[1]` below would select the token
        # mask instead of the pooled output, and `outputs.hidden_states` would fail on a tuple.
        # (Without `return_dict` the backbone returns a flat tuple that needs no unwrapping.)
        if return_dict and isinstance(outputs, tuple):
            outputs = outputs[0]

        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # infer the problem type once from the number of labels and the label dtype
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # flat tuple: logits followed by whatever the backbone returned after the pooled output
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
@@ -0,0 +1,400 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from timm.layers import trunc_normal_
|
||||
from torch import nn
|
||||
|
||||
from selva_core.ext.synchformer.utils import check_if_file_exists_else_download
|
||||
from selva_core.ext.synchformer.video_model_builder import VisionTransformer
|
||||
|
||||
# Maps a config / checkpoint file name to the URL it can be downloaded from.
# The three variants share the same naming scheme, so the table is generated from it.
_MFORMER_VARIANTS = ('motionformer', 'joint', 'divided')
_MFORMER_CFG_ROOT = 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2'
_MFORMER_CKPT_ROOT = 'https://dl.fbaipublicfiles.com/motionformer'

FILE2URL = {
    # cfg
    **{
        f'{variant}_224_16x4.yaml': f'{_MFORMER_CFG_ROOT}/{variant}_224_16x4.yaml'
        for variant in _MFORMER_VARIANTS
    },
    # ckpt
    **{
        f'ssv2_{variant}_224_16x4.pyth': f'{_MFORMER_CKPT_ROOT}/ssv2_{variant}_224_16x4.pyth'
        for variant in _MFORMER_VARIANTS
    },
}
|
||||
|
||||
|
||||
class MotionFormer(VisionTransformer):
    ''' This class serves three purposes:
        1. Renames the class to MotionFormer.
        2. Downloads the cfg from the original repo and patches it if needed.
        3. Takes care of feature extraction by redefining .forward()
            - if `extract_features=True` and `factorize_space_time=False`,
                the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
            - if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D)
                and spatial and temporal transformer encoder layers are used.
            - if `extract_features=True` and `factorize_space_time=True` as well as `add_global_repr=True`
                the output is of shape (B, D) and spatial and temporal transformer encoder layers
                are used as well as the global representation is extracted from segments (extra pos emb
                is added).

    NOTE(review): the shapes above are the original description. In the code of this class, `forward`
    always folds the result back to (B, S, ...), and `global_attn_agg` is only constructed in
    `__init__`; `forward` does not apply it.

    Args:
        extract_features: replace `pre_logits`, `head` and `head_drop` with identities and build the
            aggregation layers. `forward_segments` asserts that this is set.
        ckpt_path: checkpoint to initialize from. Either one of the MotionFormer checkpoints listed in
            `FILE2URL` or a `.pt` checkpoint (treated as a Stage I / AVCLIP checkpoint).
            `None` means no weights are loaded and the 'divided' config is used.
        factorize_space_time: aggregate space and then time with `agg_space_module` / `agg_time_module`.
        agg_space_module: 'TransformerEncoderLayer' or 'AveragePooling'.
        agg_time_module: 'TransformerEncoderLayer', 'AveragePooling', or a name containing 'Identity'.
            Must be given when `factorize_space_time` and `extract_features` are set.
        add_global_repr: build `global_attn_agg` according to `agg_segments_module`.
        agg_segments_module: 'TransformerEncoderLayer' or 'AveragePooling'.
        max_segments: maximum number of positions for the global aggregation layer (16 if `None`).

    Raises:
        ValueError: if `ckpt_path` is neither a known MotionFormer checkpoint nor a `.pt` file.
        TypeError: if factorized feature extraction is requested but `agg_time_module` is `None`.
    '''

    def __init__(
        self,
        extract_features: bool = False,
        ckpt_path: str = None,
        factorize_space_time: bool = None,
        agg_space_module: str = None,
        agg_time_module: str = None,
        add_global_repr: bool = True,
        agg_segments_module: str = None,
        max_segments: int = None,
    ):
        # plain attributes are set before `super().__init__()` because the cfg passed to the parent
        # constructor depends on the checkpoint
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.factorize_space_time = factorize_space_time

        if self.ckpt_path is not None:
            check_if_file_exists_else_download(self.ckpt_path, FILE2URL)
            # NOTE(security): `torch.load` unpickles the file; only use checkpoints from trusted sources.
            ckpt = torch.load(self.ckpt_path, map_location='cpu')
            mformer_ckpt2cfg = {
                'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml',
                'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml',
                'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml',
            }
            # init from motionformer ckpt or from our Stage I ckpt
            # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to
            # load the state dict differently
            was_pt_on_avclip = self.ckpt_path.endswith(
                '.pt')  # checks if it is a stage I ckpt (FIXME: a bit generic)
            if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())):
                cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name]
            elif was_pt_on_avclip:
                # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it)
                s1_cfg = ckpt.get('args', None)  # Stage I cfg
                if s1_cfg is not None:
                    s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path
                    # if the stage I ckpt was initialized from a motionformer ckpt or train from scratch
                    if s1_vfeat_extractor_ckpt_path is not None:
                        cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name]
                    else:
                        cfg_fname = 'divided_224_16x4.yaml'
                else:
                    cfg_fname = 'divided_224_16x4.yaml'
            else:
                raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.')
        else:
            was_pt_on_avclip = False
            cfg_fname = 'divided_224_16x4.yaml'
            # logging.info(f'No ckpt_path provided, using {cfg_fname} config.')

        if cfg_fname in ['motionformer_224_16x4.yaml', 'divided_224_16x4.yaml']:
            pos_emb_type = 'separate'
        elif cfg_fname == 'joint_224_16x4.yaml':
            pos_emb_type = 'joint'

        # the config is stored next to this file (and downloaded there if it is missing)
        self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname

        check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL)
        mformer_cfg = OmegaConf.load(self.mformer_cfg_path)
        logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}')

        # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`)
        mformer_cfg.VIT.ATTN_DROPOUT = 0.0
        mformer_cfg.VIT.POS_EMBED = pos_emb_type
        mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True
        mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none'  # guessing
        mformer_cfg.VIT.APPROX_ATTN_DIM = 64  # from ckpt['cfg']

        # finally init VisionTransformer with the cfg
        super().__init__(mformer_cfg)

        # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt
        if (self.ckpt_path is not None) and (not was_pt_on_avclip):
            _ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False)
            if len(_ckpt_load_status.missing_keys) > 0 or len(
                    _ckpt_load_status.unexpected_keys) > 0:
                # (a separator was missing between the first two fragments of this message)
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. '
                                f'Missing keys: {_ckpt_load_status.missing_keys}, '
                                f'Unexpected keys: {_ckpt_load_status.unexpected_keys}')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        if self.extract_features:
            assert isinstance(self.norm,
                              nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights'
            # pre-logits are Sequential(nn.Linear(emb, emd), act) and `act` is tanh but see the logger
            self.pre_logits = nn.Identity()
            # we don't need the classification head (saving memory)
            self.head = nn.Identity()
            self.head_drop = nn.Identity()
            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
            transf_enc_layer_kwargs = dict(
                d_model=self.embed_dim,
                nhead=self.num_heads,
                activation=nn.GELU(),
                batch_first=True,
                dim_feedforward=self.mlp_ratio * self.embed_dim,
                dropout=self.drop_rate,
                layer_norm_eps=1e-6,
                norm_first=True,
            )
            # define adapters if needed
            if self.factorize_space_time:
                if agg_space_module == 'TransformerEncoderLayer':
                    self.spatial_attn_agg = SpatialTransformerEncoderLayer(
                        **transf_enc_layer_kwargs)
                elif agg_space_module == 'AveragePooling':
                    self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t',
                                                           then_permute_pattern='BS D t -> BS t D')
                if agg_time_module == 'TransformerEncoderLayer':
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == 'AveragePooling':
                    self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
                elif agg_time_module is None:
                    # Without this branch the substring test below fails with an opaque
                    # "argument of type 'NoneType' is not iterable"; keep the exception type but
                    # say what is actually wrong.
                    raise TypeError("`agg_time_module` must be provided when `factorize_space_time` "
                                    "is enabled (e.g. 'TransformerEncoderLayer', 'AveragePooling' "
                                    "or 'Identity').")
                elif 'Identity' in agg_time_module:
                    self.temp_attn_agg = nn.Identity()
            # define a global aggregation layer (aggregate over segments)
            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == 'TransformerEncoderLayer':
                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
                    # we need to add pos emb (PE) because previously we added the same PE for each segment
                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True,
                        pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
                        pos_max_len=pos_max_len,
                        **transf_enc_layer_kwargs)
                elif agg_segments_module == 'AveragePooling':
                    self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')

        if was_pt_on_avclip:
            # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
            # and keep only the state_dict of the feat extractor
            ckpt_weights = dict()
            for k, v in ckpt['state_dict'].items():
                if k.startswith(('module.v_encoder.', 'v_encoder.')):
                    k = k.replace('module.', '').replace('v_encoder.', '')
                    ckpt_weights[k] = v
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n'
                                f'Missing keys ({len(_load_status.missing_keys)}): '
                                f'{_load_status.missing_keys}, \n'
                                f'Unexpected keys ({len(_load_status.unexpected_keys)}): '
                                f'{_load_status.unexpected_keys} \n'
                                f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1
        # but it is used to calculate the number of patches, so we need to keep it (frozen)
        self.patch_embed.requires_grad_(False)

    def forward(self, x):
        '''
        x is of shape (B, S, C, T, H, W) where S is the number of segments.
        Returns a tensor of shape (B, S, ...): the per-segment output of `forward_segments` with the
        batch and segment dimensions separated again.
        '''
        # Batch, Segments, Channels, T=frames, Height, Width
        B, S, C, T, H, W = x.shape

        orig_shape = (B, S, C, T, H, W)
        # flatten batch and segments; `reshape` instead of `view` so that inputs whose strides do not
        # allow a zero-copy merge of the first two dims (e.g. after slicing) are still accepted
        x = x.reshape(B * S, C, T, H, W)
        x = self.forward_segments(x, orig_shape=orig_shape)
        # here x is (B*S, D), or (B*S, t, D) if `self.temp_attn_agg` is `Identity`;
        # unpack the segments (using rest dimensions to support both shapes)
        x = x.view(B, S, *x.shape[1:])

        return x  # x is (B, S, ...)

    def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
        '''x is of shape (B*S, C, T, H, W) where S is the number of segments;
        `orig_shape` is the (B, S, C, T, H, W) shape before flattening.'''
        x, x_mask = self.forward_features(x)

        assert self.extract_features

        # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
        x = x[:,
              1:, :]  # without the CLS token for efficiency (should be safe for LayerNorm and FC)
        x = self.norm(x)
        x = self.pre_logits(x)
        if self.factorize_space_time:
            x = self.restore_spatio_temp_dims(x, orig_shape)  # (B*S, D, t, h, w) <- (B*S, t*h*w, D)

            x = self.spatial_attn_agg(x, x_mask)  # (B*S, t, D)
            x = self.temp_attn_agg(
                x)  # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity`

        return x

    def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
        '''
        feats are of shape (B*S, t*h*w, D) (the CLS token has already been removed by the caller).
        Returns them reshaped to (B*S, D, t, h, w) where t, h, w are the numbers of patches along
        time, height and width.
        From `self.patch_embed_3d`, it follows that we could reshape feats with:
            `feats.transpose(1, 2).view(B*S, D, t, h, w)`
        '''
        B, S, C, T, H, W = orig_shape
        D = self.embed_dim

        # num patches in each dimension
        t = T // self.patch_embed_3d.z_block_size
        h = self.patch_embed_3d.height
        w = self.patch_embed_3d.width

        feats = feats.permute(0, 2, 1)  # (B*S, D, t*h*w)
        feats = feats.view(B * S, D, t, h, w)  # (B*S, D, t, h, w)

        return feats
|
||||
|
||||
|
||||
class BaseEncoderLayer(nn.TransformerEncoderLayer):
    '''
    A wrapper around nn.TransformerEncoderLayer that prepends a learnable CLS token to the
    sequence and returns only the CLS token's representation.
    It is the common parent of the spatial and temporal encoder layers defined in this file.
    Optionally, a learnable positional embedding is added to the input sequence (after the CLS token
    is prepended), which allows reusing the layer for global aggregation over segments.
    '''

    def __init__(self,
                 add_pos_emb: bool = False,
                 pos_emb_drop: float = None,
                 pos_max_len: int = None,
                 *args_transformer_enc,
                 **kwargs_transformer_enc):
        super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
        embed_dim = self.self_attn.embed_dim

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        # optional positional embedding
        self.add_pos_emb = add_pos_emb
        if add_pos_emb:
            self.pos_max_len = 1 + pos_max_len  # +1 (for CLS)
            self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, embed_dim))
            self.pos_drop = nn.Dropout(pos_emb_drop)
            trunc_normal_(self.pos_emb, std=.02)

        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
        ''' x is of shape (B, N, D); if provided x_mask is of shape (B, N) with 1=keep, 0=mask.
        Returns the encoded CLS token, a tensor of shape (B, D). '''
        batch_dim = x.shape[0]

        # prepend the CLS token (expanded to match the batch dimension)
        x = torch.cat((self.cls_token.expand(batch_dim, -1, -1), x), dim=-2)  # (batch_dim, 1+seq_len, D)

        attn_mask = None
        if x_mask is not None:
            # the CLS token is always kept
            keep_cls = torch.ones((batch_dim, 1), dtype=torch.bool, device=x_mask.device)
            keep = torch.cat((keep_cls, x_mask), dim=-1)  # (batch_dim, 1+seq_len)
            n_rows, n_tokens = keep.shape
            n_heads = self.self_attn.num_heads
            # torch expects (N, N) or (B*num_heads, N, N) mask, so the key mask is repeated for
            # every head and every query position
            attn_mask = keep.reshape(n_rows, 1, 1, n_tokens)
            attn_mask = attn_mask.expand(-1, n_heads, n_tokens, -1)
            attn_mask = attn_mask.reshape(n_rows * n_heads, n_tokens, n_tokens)
            assert attn_mask.dtype == torch.bool, 'x_mask_w_cls.dtype != bool'
            attn_mask = ~attn_mask  # invert mask (1=mask)

        if self.add_pos_emb:
            # must be measured after the CLS token concatenation
            seq_len = x.shape[1]
            assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})'
            x = self.pos_drop(x + self.pos_emb[:, :seq_len, :])

        # apply encoder layer (calls nn.TransformerEncoderLayer.forward)
        encoded = super().forward(src=x, src_mask=attn_mask)  # (batch_dim, 1+seq_len, D)

        # only the CLS position is returned
        return encoded[:, 0, :]  # (batch_dim, D)

    def _init_weights(self, m):
        # LayerNorm -> weight 1 / bias 0; Linear -> truncated normal weight, zero bias
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token', 'pos_emb'}
|
||||
|
||||
|
||||
class SpatialTransformerEncoderLayer(BaseEncoderLayer):
    ''' Pools the spatial dimensions: attention is applied to each frame separately. '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        ''' x is of shape (B*S, D, t, h, w) where S is the number of segments.
        if specified x_mask (B*S, t, h, w), 0=masked, 1=kept
        Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame. '''
        n_segments, _, n_frames, _, _ = x.shape

        # frames join the batch dimension; the spatial grid is flattened into the sequence
        tokens = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D')
        # the mask (if any) is rearranged the same way
        token_mask = None
        if x_mask is not None:
            token_mask = einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)')

        # BaseEncoderLayer.forward adds the CLS token and returns its representation
        pooled = super().forward(x=tokens, x_mask=token_mask)  # (B*S*t, D)

        # split the batch dimension back into segments and frames: (B*S, t, D)
        return einops.rearrange(pooled, '(BS t) D -> BS t D', BS=n_segments, t=n_frames)
|
||||
|
||||
|
||||
class TemporalTransformerEncoderLayer(BaseEncoderLayer):
    ''' Pools the temporal axis with attention. With positional embeddings it is also used
    as the global aggregation in both streams. '''

    def __init__(self, *args, **kwargs):
        # plain pass-through to BaseEncoderLayer
        super().__init__(*args, **kwargs)

    def forward(self, x):
        ''' x: (B*S, t, D), S being the number of segments.
        Returns (B*S, D): one temporally pooled vector per segment. '''
        # the values are not needed; unpacking only rejects inputs that are not 3-dimensional
        _, _, _ = x.shape

        # BaseEncoderLayer.forward adds the CLS token and returns its representation
        pooled = super().forward(x)
        return pooled  # (B*S, D)
|
||||
|
||||
|
||||
class AveragePooling(nn.Module):
    ''' Mean-reduces the input along an einops pattern, optionally followed by an einops
    rearrange of the reduced tensor. '''

    def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None:
        ''' Both arguments are einops patterns, e.g. "bs t d -> bs d".
        `then_permute_pattern` is applied after the reduction when it is not None. '''
        super().__init__()
        # TODO: need to register them as buffers (but fails because these are strings)
        self.then_permute_pattern = then_permute_pattern
        self.avg_pattern = avg_pattern
        self.reduce_fn = 'mean'

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        ''' Reduce `x` with `avg_pattern`; `x_mask` is accepted but not used. '''
        pooled = einops.reduce(x, self.avg_pattern, self.reduce_fn)
        if self.then_permute_pattern is None:
            return pooled
        return einops.rearrange(pooled, self.then_permute_pattern)
|
||||
@@ -0,0 +1,144 @@
|
||||
import logging
|
||||
from typing import Any, Mapping
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from selva_core.ext.synchformer.motionformer import MotionFormer
|
||||
from selva_core.ext.synchformer.astransformer import AST
|
||||
|
||||
|
||||
class Synchformer(nn.Module):
|
||||
|
||||
def __init__(self, video: bool = True, audio: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.video = video
|
||||
self.audio = audio
|
||||
|
||||
if not video and not audio:
|
||||
raise ValueError('At least one of vis or audio should be True.')
|
||||
|
||||
if self.video:
|
||||
self.vfeat_extractor = MotionFormer(extract_features=True,
|
||||
factorize_space_time=True,
|
||||
agg_space_module='TransformerEncoderLayer',
|
||||
agg_time_module='torch.nn.Identity',
|
||||
add_global_repr=False)
|
||||
if self.audio:
|
||||
self.afeat_extractor = AST(extract_features=True,
|
||||
max_spec_t=66,
|
||||
factorize_freq_time=True,
|
||||
agg_freq_module='TransformerEncoderLayer',
|
||||
agg_time_module='torch.nn.Identity',
|
||||
add_global_repr=False)
|
||||
|
||||
# self.vfeat_extractor = instantiate_from_config(vfeat_extractor)
|
||||
# self.afeat_extractor = instantiate_from_config(afeat_extractor)
|
||||
# # bridging the s3d latent dim (1024) into what is specified in the config
|
||||
# # to match e.g. the transformer dim
|
||||
# self.vproj = instantiate_from_config(vproj)
|
||||
# self.aproj = instantiate_from_config(aproj)
|
||||
# self.transformer = instantiate_from_config(transformer)
|
||||
|
||||
def forward(self, data):
|
||||
video, audio = None, None
|
||||
|
||||
if self.video and self.audio:
|
||||
video, audio = data
|
||||
elif self.video:
|
||||
video = data
|
||||
elif self.audio:
|
||||
audio = data
|
||||
|
||||
if self.video and video is not None:
|
||||
video = self.forward_vfeat(video)
|
||||
if self.audio and audio is not None:
|
||||
audio = self.forward_afeat(audio)
|
||||
|
||||
if self.video and self.audio:
|
||||
return video, audio
|
||||
elif self.video:
|
||||
return video
|
||||
else:
|
||||
return audio
|
||||
|
||||
def forward_vfeat(self, vis):
|
||||
B, S, Tv, C, H, W = vis.shape
|
||||
vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
|
||||
# feat extractors return a tuple of segment-level and global features (ignored for sync)
|
||||
# (B, S, tv, D), e.g. (B, 7, 8, 768)
|
||||
vis = self.vfeat_extractor(vis)
|
||||
return vis
|
||||
|
||||
def forward_afeat(self, aud):
|
||||
B, S, F, Ta = aud.shape
|
||||
aud = aud.permute(0, 1, 3, 2) # (B, S, Ta, F)
|
||||
aud, _ = self.afeat_extractor(aud)
|
||||
return aud
|
||||
|
||||
|
||||
def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
|
||||
target_keys = (['vfeat_extractor'] if self.video else []) \
|
||||
+ (['afeat_extractor'] if self.audio else [])
|
||||
# discard all entries except vfeat_extractor / afeat_extractor
|
||||
sd = {k: v for k, v in sd.items() if any(k.startswith(tk)
|
||||
for tk in target_keys)}
|
||||
|
||||
|
||||
return super().load_state_dict(sd, strict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: pushes random video and audio through both extractors and prints the
    # resulting shapes. Needs a CUDA device and the checkpoint at the hard-coded path below.
    model = Synchformer(video=True, audio=True)
    model = model.cuda().eval()
    state = torch.load('/mnt/hdd3/junwon/mmaudio/ext_weights/synchformer_state_dict.pth', weights_only=True)
    model.load_state_dict(state)

    # --- visual branch ---
    vid = torch.randn(2, 7, 16, 3, 224, 224).cuda()
    vid_feats = model.forward_vfeat(vid).detach().cpu()
    print(vid_feats.shape)

    # --- audio branch: cut the waveform into half-overlapping segments ---
    aud = torch.randn(2, 16000*8).cuda()
    segment_size = 10_240  # 16000 * (16/25) = 16000 * 0.64
    step_size = 5_120  # segment_size // 2
    num_segments = (128000 - segment_size) // step_size + 1
    starts = [i * step_size for i in range(num_segments)]
    aud = torch.stack([aud[:, s:s + segment_size] for s in starts], dim=1)  # (B, S, T)
    print(aud.shape)

    import torchaudio
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        win_length=400,
        hop_length=160,
        n_fft=1024,
        n_mels=128,
    )
    to_mel = to_mel.cuda()
    aud = to_mel(aud)  # (B, S, F, T)
    aud = torch.log(aud + 1e-6)

    # bring the time axis to exactly max_spec_t frames: zero-pad on the right, then crop
    max_spec_t = 66
    missing = max_spec_t - aud.shape[-1]
    if missing > 0:
        # pad the last dim (time) -> (..., n_mels, 0+time+difference)  # safe for batched input
        aud = torch.nn.functional.pad(aud, (0, missing), 'constant', 0.0)
    aud = aud[..., :max_spec_t]  # (B, S, F, T=66)

    # normalise with the fixed statistics
    MEAN = -4.2677393
    STD = 4.5689974
    aud = (aud - MEAN) / (2 * STD)
    print(aud.shape)

    from einops import rearrange
    # fold the segments into the batch axis with a singleton segment axis
    aud = rearrange(aud, 'b s f t -> (b s) 1 f t')
    print(aud.shape)
    aud = model.forward_afeat(aud).detach().cpu()
    print(aud.shape)
    # unfold the batch again and concatenate the segments along time
    aud = rearrange(aud, '(b s) 1 t d -> b (s t) d', b=2)
    print(aud.shape)
|
||||
|
||||
|
||||
# extract and save the state dict only
|
||||
# sd = torch.load('./ext_weights/sync_model_audioset.pt')['model']
|
||||
# torch.save(sd, './ext_weights/synchformer_state_dict.pth')
|
||||
@@ -0,0 +1,92 @@
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
PARENT_LINK = 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a'

# Time stamps of the released sync models. Each one provides a `<stamp>.pt` checkpoint and a
# `cfg-<stamp>.yaml` config, both stored under `sync/sync_models/<stamp>/`.
_SYNC_MODEL_STAMPS = (
    '24-01-22T20-34-52',  # S3: Synchability: AudioSet (run 2)
    '24-01-04T16-39-21',  # S2: Synchformer: AudioSet (run 2)
    '23-08-28T11-23-23',  # S2: Synchformer: AudioSet (run 1)
    '23-12-23T18-33-57',  # S2: Synchformer: LRS3 (run 2)
    '24-01-02T10-00-53',  # S2: Synchformer: VGS (run 2)
    '22-09-21T21-00-52',  # SparseSync: ft VGGSound-Full
    '22-07-28T15-49-45',  # SparseSync: ft VGGSound-Sparse
    '22-07-13T22-25-49',  # SparseSync: only pt on LRS3
)

# SparseSync: feature extractors, stored directly under `sync/`
_RESNET_AUDIO_STAMPS = (
    '22-08-04T09-51-04',  # 2s
    '22-08-03T23-14-49',  # 3s
    '22-08-03T23-14-28',  # 4s
    '22-06-24T08-10-33',  # 5s
    '22-06-24T17-31-07',  # 6s
    '22-06-24T23-57-11',  # 7s
    '22-06-25T04-35-42',  # 8s
)

# file name -> download link; checkpoint first, then its config, in the order listed above
FNAME2LINK = {
    fname: link
    for stamp in _SYNC_MODEL_STAMPS
    for fname, link in (
        (f'{stamp}.pt', f'{PARENT_LINK}/sync/sync_models/{stamp}/{stamp}.pt'),
        (f'cfg-{stamp}.yaml', f'{PARENT_LINK}/sync/sync_models/{stamp}/cfg-{stamp}.yaml'),
    )
}
FNAME2LINK.update({
    f'ResNetAudio-{stamp}.pt': f'{PARENT_LINK}/sync/ResNetAudio-{stamp}.pt'
    for stamp in _RESNET_AUDIO_STAMPS
})
|
||||
|
||||
|
||||
def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024):
    '''Make sure `path` exists, downloading it when it is missing.

    path: destination file; its file name is the lookup key in `fname2link`.
    fname2link: mapping from file name to download link.
    chunk_size: number of bytes requested per streamed chunk.

    Does nothing when `path` already exists. Otherwise the parent folder is created and the
    file is streamed into a temporary `<name>.part` sibling that is renamed to `path` only
    after the transfer completed.

    Raises ValueError when the file is missing and no link is known for its name, and
    requests.HTTPError when the server answers with an error status.
    '''
    path = Path(path)
    if path.exists():
        return

    path.parent.mkdir(exist_ok=True, parents=True)
    link = fname2link.get(path.name, None)
    if link is None:
        raise ValueError(f'Cant find the checkpoint file: {path}. '
                         f'Please download it manually and ensure the path exists.')

    # An interrupted transfer must not leave a truncated file at `path`: the next call would
    # see that the path exists and silently accept it.
    tmp_path = path.with_name(path.name + '.part')
    try:
        with requests.get(link, stream=True) as r:
            # without this an HTML error page would be stored as the checkpoint
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
                with open(tmp_path, 'wb') as f:
                    for data in r.iter_content(chunk_size=chunk_size):
                        if data:
                            f.write(data)
                            # chunks may be shorter than chunk_size (at least the last one)
                            pbar.update(len(data))
        tmp_path.replace(path)
    finally:
        # only still present when the download failed part-way
        if tmp_path.exists():
            tmp_path.unlink()
|
||||
|
||||
|
||||
def get_md5sum(path):
    '''Returns the hexadecimal MD5 digest of the file at `path`.

    The file is read in binary mode in blocks of 4096 * 8 bytes, so it never has to fit
    into memory as a whole.
    '''
    digest = md5()
    block_size = 4096 * 8
    with open(path, 'rb') as stream:
        while True:
            block = stream.read(block_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
|
||||
@@ -0,0 +1,277 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
# Copyright 2020 Ross Wightman
|
||||
# Modified Model definition
|
||||
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.layers import trunc_normal_
|
||||
|
||||
from selva_core.ext.synchformer import vit_helper
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
    """ Vision Transformer backbone for video input, configured from a `cfg` object.

    The constructor reads its hyper-parameters from cfg.DATA, cfg.TRAIN, cfg.MODEL and
    cfg.VIT. `forward_features` tokenises the input with the 3D patch embedding, prepends a
    CLS token, adds positional embeddings and runs the transformer blocks. The final norm,
    the pre-logits MLP and the classification head(s) are built here but are not applied by
    `forward_features`.
    """

    def __init__(self, cfg):
        """ Build embeddings, transformer blocks and head(s) from `cfg`, then initialise
        every nn.Linear / nn.LayerNorm through `_init_weights`. """
        super().__init__()
        self.img_size = cfg.DATA.TRAIN_CROP_SIZE
        self.patch_size = cfg.VIT.PATCH_SIZE
        self.in_chans = cfg.VIT.CHANNELS
        # Epickitchens gets two heads with 97 and 300 outputs
        if cfg.TRAIN.DATASET == "Epickitchens":
            self.num_classes = [97, 300]
        else:
            self.num_classes = cfg.MODEL.NUM_CLASSES
        self.embed_dim = cfg.VIT.EMBED_DIM
        self.depth = cfg.VIT.DEPTH
        self.num_heads = cfg.VIT.NUM_HEADS
        self.mlp_ratio = cfg.VIT.MLP_RATIO
        self.qkv_bias = cfg.VIT.QKV_BIAS
        self.drop_rate = cfg.VIT.DROP
        self.drop_path_rate = cfg.VIT.DROP_PATH
        self.head_dropout = cfg.VIT.HEAD_DROPOUT
        self.video_input = cfg.VIT.VIDEO_INPUT
        self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
        self.use_mlp = cfg.VIT.USE_MLP
        self.num_features = self.embed_dim
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
        self.head_act = cfg.VIT.HEAD_ACT
        self.cfg = cfg

        # Patch Embedding
        # NOTE(review): img_size is hard-coded to 224 instead of self.img_size. This module is
        # not applied in forward_features; it only supplies `num_patches`, which sizes
        # `pos_embed` and the `seq_len` handed to the blocks.
        self.patch_embed = vit_helper.PatchEmbed(img_size=224,
                                                 patch_size=self.patch_size,
                                                 in_chans=self.in_chans,
                                                 embed_dim=self.embed_dim)

        # 3D Patch Embedding (the tokeniser that forward_features actually applies)
        self.patch_embed_3d = vit_helper.PatchEmbed3D(img_size=self.img_size,
                                                      temporal_resolution=self.temporal_resolution,
                                                      patch_size=self.patch_size,
                                                      in_chans=self.in_chans,
                                                      embed_dim=self.embed_dim,
                                                      z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP)
        # the 3D projection weight starts from zeros
        self.patch_embed_3d.proj.weight.data = torch.zeros_like(
            self.patch_embed_3d.proj.weight.data)

        # Number of patches
        if self.video_input:
            num_patches = self.patch_embed.num_patches * self.temporal_resolution
        else:
            num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        # Positional embedding (spatial patches of a single frame, plus one slot for CLS)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
        trunc_normal_(self.pos_embed, std=.02)

        # "joint": one embedding for every space-time token (+ CLS);
        # "separate": an extra per-frame embedding that is combined with pos_embed in
        # forward_features (it stays at its zero initialisation here)
        if self.cfg.VIT.POS_EMBED == "joint":
            self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
            trunc_normal_(self.st_embed, std=.02)
        elif self.cfg.VIT.POS_EMBED == "separate":
            self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim))

        # Layer Blocks
        # drop-path rate per block, linearly spaced from 0 to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
        if self.cfg.VIT.ATTN_LAYER == "divided":
            self.blocks = nn.ModuleList([
                vit_helper.DividedSpaceTimeBlock(
                    attn_type=cfg.VIT.ATTN_LAYER,
                    dim=self.embed_dim,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    qkv_bias=self.qkv_bias,
                    drop=self.drop_rate,
                    attn_drop=self.attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                ) for i in range(self.depth)
            ])
        else:
            self.blocks = nn.ModuleList([
                vit_helper.Block(attn_type=cfg.VIT.ATTN_LAYER,
                                 dim=self.embed_dim,
                                 num_heads=self.num_heads,
                                 mlp_ratio=self.mlp_ratio,
                                 qkv_bias=self.qkv_bias,
                                 drop=self.drop_rate,
                                 attn_drop=self.attn_drop_rate,
                                 drop_path=dpr[i],
                                 norm_layer=norm_layer,
                                 use_original_code=self.cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE)
                for i in range(self.depth)
            ])
        self.norm = norm_layer(self.embed_dim)

        # MLP head
        if self.use_mlp:
            hidden_dim = self.embed_dim
            # any head_act other than 'tanh' / 'gelu' falls back to ReLU
            if self.head_act == 'tanh':
                # logging.info("Using TanH activation in MLP")
                act = nn.Tanh()
            elif self.head_act == 'gelu':
                # logging.info("Using GELU activation in MLP")
                act = nn.GELU()
            else:
                # logging.info("Using ReLU activation in MLP")
                act = nn.ReLU()
            self.pre_logits = nn.Sequential(
                OrderedDict([
                    ('fc', nn.Linear(self.embed_dim, hidden_dim)),
                    ('act', act),
                ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier Head
        self.head_drop = nn.Dropout(p=self.head_dropout)
        # a list with more than one entry creates head0, head1, ... and no `self.head`
        if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
            # `a` and `i` always hold the same index
            for a, i in enumerate(range(len(self.num_classes))):
                setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i]))
        else:
            self.head = nn.Linear(self.embed_dim,
                                  self.num_classes) if self.num_classes > 0 else nn.Identity()

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """ Callback for `self.apply`: nn.Linear gets a trunc_normal_ (std=0.02) weight and a
        zero bias (when it has one); nn.LayerNorm gets bias 0 and weight 1. """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """ Return a set of parameter names that depends on cfg.VIT.POS_EMBED.

        NOTE(review): every value other than "joint" reports 'temp_embed', although that
        parameter is only created for "separate". """
        if self.cfg.VIT.POS_EMBED == "joint":
            return {'pos_embed', 'cls_token', 'st_embed'}
        else:
            return {'pos_embed', 'cls_token', 'temp_embed'}

    def get_classifier(self):
        """ Return `self.head` (not defined when several heads were created). """
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """ Replace `self.head` with a fresh nn.Linear with `num_classes` outputs, or with
        nn.Identity when `num_classes` is not positive. `global_pool` is not used. """
        self.num_classes = num_classes
        self.head = (nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity())

    def forward_features(self, x):
        """ Encode `x` into a token sequence.

        `x` is passed straight to the 3D patch embedding. Returns `(x, tok_mask)`: the tokens
        after the last block (CLS token first) and a token mask that is always None in this
        implementation. `self.norm` and `self.pre_logits` are not applied here.
        """
        # if self.video_input:
        #     x = x[0]
        B = x.shape[0]

        # Tokenize input
        # if self.cfg.VIT.PATCH_SIZE_TEMP > 1:
        # for simplicity of mapping between content dimensions (input x) and token dims (after patching)
        # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details):

        # apply patching on input
        x = self.patch_embed_3d(x)
        tok_mask = None

        # (disabled) 2D tokenisation path:
        # else:
        #     tok_mask = None
        #     # 2D tokenization
        #     if self.video_input:
        #         x = x.permute(0, 2, 1, 3, 4)
        #         (B, T, C, H, W) = x.shape
        #         x = x.reshape(B * T, C, H, W)

        #     x = self.patch_embed(x)

        #     if self.video_input:
        #         (B2, T2, D2) = x.shape
        #         x = x.reshape(B, T * T2, D2)

        # Append CLS token (it is placed in front of the patch tokens)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # if tok_mask is not None:
        #     # prepend 1(=keep) to the mask to account for the CLS token as well
        #     tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1)

        # Interpolate positional embeddings (disabled: pos_embed is always used as is)
        # if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
        #     pos_embed = self.pos_embed
        #     N = pos_embed.shape[1] - 1
        #     npatch = int((x.size(1) - 1) / self.temporal_resolution)
        #     class_emb = pos_embed[:, 0]
        #     pos_embed = pos_embed[:, 1:]
        #     dim = x.shape[-1]
        #     pos_embed = torch.nn.functional.interpolate(
        #         pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
        #         scale_factor=math.sqrt(npatch / N),
        #         mode='bicubic',
        #     )
        #     pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        #     new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
        # else:
        new_pos_embed = self.pos_embed
        npatch = self.patch_embed.num_patches

        # Add positional embeddings to input
        if self.video_input:
            if self.cfg.VIT.POS_EMBED == "separate":
                # CLS keeps its own slot; the spatial embedding is tiled over the frames and
                # summed with the per-frame embedding repeated over the patches of a frame
                cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
                tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
                tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
                total_pos_embed = tile_pos_embed + tile_temporal_embed
                total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
                x = x + total_pos_embed
            elif self.cfg.VIT.POS_EMBED == "joint":
                x = x + self.st_embed
            # any other POS_EMBED value adds no positional embedding for video input
        else:
            # image input
            x = x + new_pos_embed

        # Apply positional dropout
        x = self.pos_drop(x)

        # Encoding using transformer layers
        for i, blk in enumerate(self.blocks):
            x = blk(x,
                    seq_len=npatch,
                    num_frames=self.temporal_resolution,
                    approx=self.cfg.VIT.APPROX_ATTN_TYPE,
                    num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
                    tok_mask=tok_mask)

        ### v-iashin: I moved it to the forward pass
        # x = self.norm(x)[:, 0]
        # x = self.pre_logits(x)
        ###
        return x, tok_mask
|
||||
|
||||
# def forward(self, x):
|
||||
# x = self.forward_features(x)
|
||||
# ### v-iashin: here. This should leave the same forward output as before
|
||||
# x = self.norm(x)[:, 0]
|
||||
# x = self.pre_logits(x)
|
||||
# ###
|
||||
# x = self.head_drop(x)
|
||||
# if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
|
||||
# output = []
|
||||
# for head in range(len(self.num_classes)):
|
||||
# x_out = getattr(self, "head%d" % head)(x)
|
||||
# if not self.training:
|
||||
# x_out = torch.nn.functional.softmax(x_out, dim=-1)
|
||||
# output.append(x_out)
|
||||
# return output
|
||||
# else:
|
||||
# x = self.head(x)
|
||||
# if not self.training:
|
||||
# x = torch.nn.functional.softmax(x, dim=-1)
|
||||
# return x
|
||||
@@ -0,0 +1,399 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
# Copyright 2020 Ross Wightman
|
||||
# Modified Model definition
|
||||
"""Video models."""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
from timm.layers import to_2tuple
|
||||
from torch import einsum
|
||||
from torch.nn import functional as F
|
||||
|
||||
# release folder that hosts the pretrained ViT checkpoints referenced below
_VIT_RELEASE_URL = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx'

# name of the pretrained weights -> checkpoint URL
default_cfgs = {
    'vit_1k': f'{_VIT_RELEASE_URL}/jx_vit_base_p16_224-80ecf9dd.pth',
    'vit_1k_large': f'{_VIT_RELEASE_URL}/jx_vit_large_p16_224-4ee7a4dc.pth',
}
|
||||
|
||||
|
||||
def qkv_attn(q, k, v, tok_mask: torch.Tensor = None):
    """ Softmax attention without any scaling of its own.

    q is (b, i, d); k and v are (b, j, d). `tok_mask`, when given, is (b, j) where entries
    equal to 0 mark keys that must not be attended to (1s - keep).
    Returns the attended values of shape (b, i, d).
    """
    scores = einsum('b i d, b j d -> b i j', q, k)
    if tok_mask is not None:
        n_batch, n_keys = tok_mask.shape
        # the singleton axis broadcasts the key mask over all queries
        dropped = tok_mask.view(n_batch, 1, n_keys) == 0
        scores = scores.masked_fill(dropped, float('-inf'))
    weights = scores.softmax(dim=-1)
    return einsum('b i j, b j d -> b i d', weights, v)
|
||||
|
||||
|
||||
class DividedAttention(nn.Module):
    """ Multi-head attention applied along one factor (time or space) of a token sequence.

    The first token is treated as CLS: it attends to every token, while the remaining tokens
    are regrouped with the einops patterns given to `forward` and attend within their group
    plus the CLS key / value.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        """ dim: token dimension (split over `num_heads`); qkv_bias: whether the q/k/v
        projection has a bias; attn_drop / proj_drop: dropout probabilities. """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # constant init: qkv to zeros, output projection weight to ones and its bias to zeros
        self.qkv.weight.data.fill_(0)
        # with qkv_bias=False (the default) nn.Linear has no bias tensor to fill
        if self.qkv.bias is not None:
            self.qkv.bias.data.fill_(0)
        self.proj.weight.data.fill_(1)
        self.proj.bias.data.fill_(0)

        # NOTE: attn_drop is created but not applied anywhere in forward
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
        """ x: (b, 1 + n, dim) with the CLS token first.
        einops_from / einops_to: patterns describing how the non-CLS tokens are regrouped for
            attention, e.g. 'b (f n) d' -> '(b f) n d'; `einops_dims` supplies the axis sizes.
        tok_mask (optional): (b, 1 + n), 1=keep and 0=masked.
        Returns a tensor of the same shape as `x`. """
        # num of heads variable
        h = self.num_heads

        # project x to q, k, v values and fold the heads into the batch axis
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        if tok_mask is not None:
            # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d
            assert len(tok_mask.shape) == 2
            tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1])

        # Scale q (out of place, so no tensor that may be a view of the projection is mutated)
        q = q * self.scale

        # Take out cls_q, cls_k, cls_v
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
        # the same for masking
        if tok_mask is not None:
            cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:]
        else:
            cls_mask, mask_ = None, None

        # let CLS token attend to key / values of all patches across time and space
        cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask)

        # rearrange across time or space
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims),
                         (q_, k_, v_))

        # expand CLS token keys and values across time or space and concat
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim=1)
        v_ = torch.cat((cls_v, v_), dim=1)

        # the same for masking (if provided)
        if tok_mask is not None:
            # since mask does not have the latent dim (d), we need to remove it from einops dims
            mask_ = rearrange(mask_, f'{einops_from} -> {einops_to}'.replace(' d', ''),
                              **einops_dims)
            cls_mask = repeat(cls_mask, 'b () -> (b r) ()',
                              r=r)  # expand cls_mask across time or space
            mask_ = torch.cat((cls_mask, mask_), dim=1)

        # attention
        out = qkv_attn(q_, k_, v_, tok_mask=mask_)

        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # concat back the cls token
        out = torch.cat((cls_out, out), dim=1)

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        ## to out
        x = self.proj(out)
        x = self.proj_drop(x)
        return x
|
||||
|
||||
|
||||
class DividedSpaceTimeBlock(nn.Module):
    """ Transformer block with divided attention: temporal attention, then spatial attention,
    then an MLP, each wrapped in a residual connection.

    `attn_type` and `drop_path` are accepted but not used: the drop-path slot is always
    nn.Identity.
    """

    def __init__(self,
                 dim=768,
                 num_heads=12,
                 attn_type='divided',
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()

        # einops patterns used to regroup the (frame, patch) tokens for each attention type
        self.einops_from_space = 'b (f n) d'
        self.einops_to_space = '(b f) n d'
        self.einops_from_time = 'b (f n) d'
        self.einops_to_time = '(b n) f d'

        # both attention modules share the same hyper-parameters
        attn_kwargs = dict(num_heads=num_heads,
                           qkv_bias=qkv_bias,
                           attn_drop=attn_drop,
                           proj_drop=drop)

        self.norm1 = norm_layer(dim)
        self.attn = DividedAttention(dim, **attn_kwargs)
        self.timeattn = DividedAttention(dim, **attn_kwargs)

        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)
        self.norm3 = norm_layer(dim)

    def forward(self,
                x,
                seq_len=196,
                num_frames=8,
                approx='none',
                num_landmarks=128,
                tok_mask: torch.Tensor = None):
        """ x: tokens with the CLS token first; seq_len: patches per frame; num_frames: number
        of frames; tok_mask: optional token mask forwarded to both attention modules.
        `approx` and `num_landmarks` are accepted but not used. Returns the updated tokens. """
        # temporal attention (pre-normed with norm3) + residual
        attended_t = self.timeattn(self.norm3(x),
                                   self.einops_from_time,
                                   self.einops_to_time,
                                   n=seq_len,
                                   tok_mask=tok_mask)
        x = x + attended_t

        # spatial attention (pre-normed with norm1) + residual
        attended_s = self.attn(self.norm1(x),
                               self.einops_from_space,
                               self.einops_to_space,
                               f=num_frames,
                               tok_mask=tok_mask)
        x = x + self.drop_path(attended_s)

        # feed-forward (pre-normed with norm2) + residual
        return x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
|
||||
|
||||
class Mlp(nn.Module):
    """ Feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.

    `hidden_features` and `out_features` fall back to `in_features` when they are left at
    None (or are otherwise falsy). The same dropout module is used at both positions.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # run the stages strictly in sequence
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Splits an image into non-overlapping patches with a strided 2D convolution and returns
    one `embed_dim`-dimensional token per patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        """ img_size / patch_size: a tuple, or a single value that `to_2tuple` expands to a
        pair. in_chans: input channels. embed_dim: token dimension. """
        super().__init__()
        img_size = img_size if isinstance(img_size, tuple) else to_2tuple(img_size)
        # a tuple patch size is kept as given (it used to be replaced by the image size)
        patch_size = patch_size if isinstance(patch_size, tuple) else to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """ x: (B, C, H, W). Returns (B, number of patches, embed_dim). """
        # unpacking doubles as a check that the input is 4-dimensional
        B, C, H, W = x.shape
        # (B, embed_dim, H', W') -> (B, embed_dim, H'*W') -> (B, H'*W', embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
|
||||
|
||||
|
||||
class PatchEmbed3D(nn.Module):
    """ Video to patch embedding: a strided 3D convolution turns every
    (z_block_size, patch_size, patch_size) block of the input into one token.

    `temporal_resolution` is accepted but not used by this module.
    """

    def __init__(self,
                 img_size=224,
                 temporal_resolution=4,
                 in_chans=3,
                 patch_size=16,
                 z_block_size=2,
                 embed_dim=768,
                 flatten=True):
        super().__init__()
        # number of patches along each spatial side
        per_side = img_size // patch_size
        self.height = per_side
        self.width = per_side
        ### v-iashin: these two are incorrect
        # self.frames = (temporal_resolution // z_block_size)
        # self.num_patches = self.height * self.width * self.frames
        self.z_block_size = z_block_size
        ###
        # kernel and stride coincide, so the blocks do not overlap
        block = (z_block_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=block, stride=block)
        self.flatten = flatten

    def forward(self, x):
        """ x: (B, C, T, H, W). Returns (B, number of tokens, embed_dim) when `flatten` is
        set, otherwise the raw convolution output (B, embed_dim, T', H', W'). """
        # the values are not needed; unpacking only rejects inputs that are not 5-dimensional
        _, _, _, _, _ = x.shape
        x = self.proj(x)
        if not self.flatten:
            return x
        return x.flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
class HeadMLP(nn.Module):
    """ Classification head: dropout followed by a linear classifier (`n_hidden` is None), or
    by a one-hidden-layer MLP with batch norm and ReLU otherwise.

    The constructor prints the dropout probability.
    """

    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
        super().__init__()
        self.n_input = n_input
        self.n_classes = n_classes
        self.n_hidden = n_hidden
        if n_hidden is None:
            # use linear classifier
            layers = [
                nn.Dropout(p=p),
                nn.Linear(n_input, n_classes, bias=True),
            ]
        else:
            # use simple MLP classifier
            layers = [
                nn.Dropout(p=p),
                nn.Linear(n_input, n_hidden, bias=True),
                nn.BatchNorm1d(n_hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(p=p),
                nn.Linear(n_hidden, n_classes, bias=True),
            ]
        self.block_forward = nn.Sequential(*layers)
        print(f"Dropout-NLP: {p}")

    def forward(self, x):
        """ Apply the head to `x` and return its output. """
        return self.block_forward(x)
|
||||
|
||||
|
||||
def _conv_filter(state_dict, patch_size=16):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
out_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if 'patch_embed.proj.weight' in k:
|
||||
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
||||
|
||||
def adapt_input_conv(in_chans, conv_weight, agg='sum'):
    """Adapt a 4-D input-conv weight to a different number of input channels.

    The computation runs in float32 and the result is cast back to the dtype
    of ``conv_weight``.

    * ``in_chans == 1``: if the weight has more than 3 input channels (which
      must be a multiple of 3) each group of 3 is summed; otherwise the input
      channels are summed (``agg == 'sum'``) or averaged (any other ``agg``)
      into one.
    * ``in_chans`` neither 1 nor 3: only 3-channel weights are supported.
      With ``agg == 'sum'`` the weight is tiled, cut to ``in_chans`` channels
      and scaled by ``3 / in_chans``; otherwise the channel mean is repeated
      ``in_chans`` times.
    * ``in_chans == 3``: the weight is returned with no channel change.

    Raises:
        NotImplementedError: ``in_chans`` is neither 1 nor 3 and the weight
            does not have exactly 3 input channels.
    """
    original_dtype = conv_weight.dtype
    weight = conv_weight.float()
    n_out, n_in, k_h, k_w = weight.shape

    if in_chans == 1:
        if n_in > 3:
            assert weight.shape[1] % 3 == 0
            # For models with space2depth stems
            grouped = weight.reshape(n_out, n_in // 3, 3, k_h, k_w)
            weight = grouped.sum(dim=2, keepdim=False)
        elif agg == 'sum':
            print("Summing conv1 weights")
            weight = weight.sum(dim=1, keepdim=True)
        else:
            print("Averaging conv1 weights")
            weight = weight.mean(dim=1, keepdim=True)
    elif in_chans != 3:
        if n_in != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        if agg == 'sum':
            print("Summing conv1 weights")
            copies = int(math.ceil(in_chans / 3))
            weight = weight.repeat(1, copies, 1, 1)[:, :in_chans, :, :]
            # Rescale in place (on the tiled copy, not the caller's tensor).
            weight *= (3 / float(in_chans))
        else:
            print("Averaging conv1 weights")
            channel_mean = weight.mean(dim=1, keepdim=True)
            weight = channel_mean.repeat(1, in_chans, 1, 1)

    return weight.to(original_dtype)
|
||||
|
||||
|
||||
def load_pretrained(model,
                    cfg=None,
                    num_classes=1000,
                    in_chans=3,
                    filter_fn=None,
                    strict=True,
                    progress=False):
    """Download pretrained ViT weights and copy compatible tensors into ``model``.

    The checkpoint URL is ``default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS]``. After
    optional filtering and adaptation, every checkpoint tensor whose name
    (with any ``'module.'`` substring removed) exists in ``model.state_dict()``
    with an identical shape is copied in place. Skipped tensors and model
    keys that received no value are reported with ``print``.

    Args:
        model: module whose ``state_dict()`` tensors are overwritten in place.
        cfg: config exposing ``cfg.VIT.PRETRAINED_WEIGHTS`` and
            ``cfg.get('label_offset', 0)``. Required despite the ``None``
            default.
        num_classes: if different from 1000 (the pretraining class count),
            the ``head.weight`` / ``head.bias`` entries are discarded.
        in_chans: if not 3, ``patch_embed.proj.weight`` is adapted with
            ``adapt_input_conv(..., agg='avg')``; when that conversion is not
            supported the weight is dropped and the layer keeps its current
            initialisation.
        filter_fn: optional callable applied to the downloaded state dict.
        strict: accepted for interface compatibility; not used, loading is
            always by name-and-shape match.
        progress: accepted for interface compatibility; not used.

    Raises:
        KeyError: ``cfg.VIT.PRETRAINED_WEIGHTS`` is not a key of
            ``default_cfgs``.
    """
    # The previous `assert (f"...")` asserted a non-empty string, which is
    # always true, so it never validated anything. Check explicitly and raise
    # the same exception type the lookup below would have produced, but with
    # a readable message.
    weights_key = cfg.VIT.PRETRAINED_WEIGHTS
    if weights_key not in default_cfgs:
        raise KeyError(f"{weights_key} not in {list(default_cfgs)}")
    state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[weights_key])

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    # Adapt the input convolution when the model does not take 3 channels.
    if in_chans != 3:
        for input_conv_name in ('patch_embed.proj', ):
            weight_name = input_conv_name + '.weight'
            try:
                state_dict[weight_name] = adapt_input_conv(in_chans,
                                                           state_dict[weight_name],
                                                           agg='avg')
                print(
                    f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)'
                )
            except NotImplementedError:
                # Unsupported layout: drop the weight so it is not loaded.
                del state_dict[weight_name]
                print(
                    f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.'
                )

    classifier_name = 'head'
    label_offset = cfg.get('label_offset', 0)
    pretrain_classes = 1000
    if num_classes != pretrain_classes:
        # completely discard fully connected if model num_classes doesn't match pretrained weights
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
    elif label_offset > 0:
        # special case for pretrained weights with an extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]

    # Copy every tensor that matches a model entry by name and shape.
    self_state = model.state_dict()
    all_names = set(self_state.keys())
    saved_names = set()
    for name, param in state_dict.items():
        if 'module.' in name:
            name = name.replace('module.', '')
        if name in self_state and param.shape == self_state[name].shape:
            saved_names.add(name)
            self_state[name].copy_(param)
        else:
            print(f"didnt load: {name} of shape: {param.shape}")
    print("Missing Keys:")
    print(all_names - saved_names)
|
||||
Reference in New Issue
Block a user