feat: add AST (Audio Spectrogram Transformer) embedding model
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -45,6 +45,7 @@ os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface"))
|
||||
_w2v_model = None
|
||||
_w2v_device = None
|
||||
_w2v_model_name = None
|
||||
_ast_feature_extractor = None
|
||||
|
||||
# Supported embedding models — name → embed_dim
|
||||
_EMBED_MODELS = {
|
||||
@@ -60,6 +61,9 @@ _EMBED_MODELS = {
|
||||
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||
# Transformers-based models
|
||||
"AST": 768,
|
||||
"AST_ML": 3072, # 768 * 4
|
||||
}
|
||||
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
||||
|
||||
@@ -90,6 +94,15 @@ def _get_w2v_model(model_name: str | None = None):
|
||||
_w2v_model = BEATs(cfg)
|
||||
_w2v_model.load_state_dict(checkpoint['model'])
|
||||
_w2v_model.to(_w2v_device)
|
||||
elif load_name == "AST":
|
||||
from transformers import ASTModel, ASTFeatureExtractor
|
||||
_w2v_model = ASTModel.from_pretrained(
|
||||
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||
).to(_w2v_device)
|
||||
global _ast_feature_extractor
|
||||
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||
)
|
||||
else:
|
||||
import torchaudio
|
||||
bundle = getattr(torchaudio.pipelines, load_name)
|
||||
@@ -204,6 +217,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||
import torch
|
||||
model, device = _get_w2v_model(model_name)
|
||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
# Auto-size batches based on available GPU memory
|
||||
batch_size = 16
|
||||
@@ -234,6 +248,19 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif is_ast:
|
||||
inputs = _ast_feature_extractor(
|
||||
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
input_values = inputs.input_values.to(device)
|
||||
if ml_cfg is not None:
|
||||
out = model(input_values, output_hidden_states=True)
|
||||
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
out = model(input_values)
|
||||
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
@@ -313,6 +340,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||
embeddings_list: list[np.ndarray] = []
|
||||
|
||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
|
||||
for batch_start in range(0, len(valid_times), batch_size):
|
||||
@@ -328,6 +356,19 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif is_ast:
|
||||
inputs = _ast_feature_extractor(
|
||||
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
input_values = inputs.input_values.to(device)
|
||||
if ml_cfg is not None:
|
||||
out = model(input_values, output_hidden_states=True)
|
||||
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
out = model(input_values)
|
||||
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
|
||||
@@ -47,6 +47,12 @@ def test_ml_config():
|
||||
assert layers == [2, 5, 8, 11]
|
||||
|
||||
|
||||
def test_embed_dim_ast():
|
||||
from core.audio_scan import _embed_dim
|
||||
assert _embed_dim("AST") == 768
|
||||
assert _embed_dim("AST_ML") == 3072
|
||||
|
||||
|
||||
def test_db_get_all_export_paths():
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
|
||||
Reference in New Issue
Block a user