feat: add AST (Audio Spectrogram Transformer) embedding model

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-19 13:55:29 +02:00
parent e3f133ef84
commit 5b25e85e98
2 changed files with 47 additions and 0 deletions
+41
View File
@@ -45,6 +45,7 @@ os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface"))
_w2v_model = None
_w2v_device = None
_w2v_model_name = None
_ast_feature_extractor = None
# Supported embedding models — name → embed_dim
_EMBED_MODELS = {
@@ -60,6 +61,9 @@ _EMBED_MODELS = {
"HUBERT_BASE_ML": 3072, # 768 * 4
"HUBERT_LARGE_ML": 4096, # 1024 * 4
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
# Transformers-based models
"AST": 768,
"AST_ML": 3072, # 768 * 4
}
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
@@ -90,6 +94,15 @@ def _get_w2v_model(model_name: str | None = None):
_w2v_model = BEATs(cfg)
_w2v_model.load_state_dict(checkpoint['model'])
_w2v_model.to(_w2v_device)
elif load_name == "AST":
from transformers import ASTModel, ASTFeatureExtractor
_w2v_model = ASTModel.from_pretrained(
"MIT/ast-finetuned-audioset-10-10-0.4593"
).to(_w2v_device)
global _ast_feature_extractor
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
"MIT/ast-finetuned-audioset-10-10-0.4593"
)
else:
import torchaudio
bundle = getattr(torchaudio.pipelines, load_name)
@@ -204,6 +217,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
import torch
model, device = _get_w2v_model(model_name)
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
# Auto-size batches based on available GPU memory
batch_size = 16
@@ -234,6 +248,19 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
batch_emb = features.mean(dim=1).cpu().numpy()
elif is_ast:
inputs = _ast_feature_extractor(
list(chunks), sampling_rate=sr, return_tensors="pt",
padding=True,
)
input_values = inputs.input_values.to(device)
if ml_cfg is not None:
out = model(input_values, output_hidden_states=True)
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
else:
out = model(input_values)
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
@@ -313,6 +340,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
embeddings_list: list[np.ndarray] = []
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
for batch_start in range(0, len(valid_times), batch_size):
@@ -328,6 +356,19 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
batch_emb = features.mean(dim=1).cpu().numpy()
elif is_ast:
inputs = _ast_feature_extractor(
list(chunks), sampling_rate=sr, return_tensors="pt",
padding=True,
)
input_values = inputs.input_values.to(device)
if ml_cfg is not None:
out = model(input_values, output_hidden_states=True)
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
else:
out = model(input_values)
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
+6
View File
@@ -47,6 +47,12 @@ def test_ml_config():
assert layers == [2, 5, 8, 11]
def test_embed_dim_ast():
from core.audio_scan import _embed_dim
assert _embed_dim("AST") == 768
assert _embed_dim("AST_ML") == 3072
def test_db_get_all_export_paths():
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
path = f.name