feat: add AST (Audio Spectrogram Transformer) embedding model
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -45,6 +45,7 @@ os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface"))
|
|||||||
_w2v_model = None
|
_w2v_model = None
|
||||||
_w2v_device = None
|
_w2v_device = None
|
||||||
_w2v_model_name = None
|
_w2v_model_name = None
|
||||||
|
_ast_feature_extractor = None
|
||||||
|
|
||||||
# Supported embedding models — name → embed_dim
|
# Supported embedding models — name → embed_dim
|
||||||
_EMBED_MODELS = {
|
_EMBED_MODELS = {
|
||||||
@@ -60,6 +61,9 @@ _EMBED_MODELS = {
|
|||||||
"HUBERT_BASE_ML": 3072, # 768 * 4
|
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||||
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||||
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||||
|
# Transformers-based models
|
||||||
|
"AST": 768,
|
||||||
|
"AST_ML": 3072, # 768 * 4
|
||||||
}
|
}
|
||||||
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
||||||
|
|
||||||
@@ -90,6 +94,15 @@ def _get_w2v_model(model_name: str | None = None):
|
|||||||
_w2v_model = BEATs(cfg)
|
_w2v_model = BEATs(cfg)
|
||||||
_w2v_model.load_state_dict(checkpoint['model'])
|
_w2v_model.load_state_dict(checkpoint['model'])
|
||||||
_w2v_model.to(_w2v_device)
|
_w2v_model.to(_w2v_device)
|
||||||
|
elif load_name == "AST":
|
||||||
|
from transformers import ASTModel, ASTFeatureExtractor
|
||||||
|
_w2v_model = ASTModel.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
).to(_w2v_device)
|
||||||
|
global _ast_feature_extractor
|
||||||
|
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
import torchaudio
|
import torchaudio
|
||||||
bundle = getattr(torchaudio.pipelines, load_name)
|
bundle = getattr(torchaudio.pipelines, load_name)
|
||||||
@@ -204,6 +217,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
import torch
|
import torch
|
||||||
model, device = _get_w2v_model(model_name)
|
model, device = _get_w2v_model(model_name)
|
||||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
# Auto-size batches based on available GPU memory
|
# Auto-size batches based on available GPU memory
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
@@ -234,6 +248,19 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif is_ast:
|
||||||
|
inputs = _ast_feature_extractor(
|
||||||
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
input_values = inputs.input_values.to(device)
|
||||||
|
if ml_cfg is not None:
|
||||||
|
out = model(input_values, output_hidden_states=True)
|
||||||
|
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
out = model(input_values)
|
||||||
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
elif ml_cfg is not None:
|
elif ml_cfg is not None:
|
||||||
all_layers, _ = model.extract_features(waveforms)
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
@@ -313,6 +340,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
embeddings_list: list[np.ndarray] = []
|
embeddings_list: list[np.ndarray] = []
|
||||||
|
|
||||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
|
|
||||||
for batch_start in range(0, len(valid_times), batch_size):
|
for batch_start in range(0, len(valid_times), batch_size):
|
||||||
@@ -328,6 +356,19 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif is_ast:
|
||||||
|
inputs = _ast_feature_extractor(
|
||||||
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
input_values = inputs.input_values.to(device)
|
||||||
|
if ml_cfg is not None:
|
||||||
|
out = model(input_values, output_hidden_states=True)
|
||||||
|
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
out = model(input_values)
|
||||||
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
elif ml_cfg is not None:
|
elif ml_cfg is not None:
|
||||||
all_layers, _ = model.extract_features(waveforms)
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
|||||||
@@ -47,6 +47,12 @@ def test_ml_config():
|
|||||||
assert layers == [2, 5, 8, 11]
|
assert layers == [2, 5, 8, 11]
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_ast():
    """Check that ``_embed_dim`` reports 768 for AST and 3072 for AST_ML."""
    from core.audio_scan import _embed_dim

    # Expected embedding width per AST variant (AST_ML is 768 * 4).
    expected_dims = {
        "AST": 768,
        "AST_ML": 3072,
    }
    for model_name, dim in expected_dims.items():
        assert _embed_dim(model_name) == dim
|
||||||
|
|
||||||
|
|
||||||
def test_db_get_all_export_paths():
|
def test_db_get_all_export_paths():
|
||||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
path = f.name
|
path = f.name
|
||||||
|
|||||||
Reference in New Issue
Block a user