diff --git a/core/audio_scan.py b/core/audio_scan.py index ad21caf..9dbb59c 100644 --- a/core/audio_scan.py +++ b/core/audio_scan.py @@ -45,6 +45,7 @@ os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface")) _w2v_model = None _w2v_device = None _w2v_model_name = None +_ast_feature_extractor = None # Supported embedding models — name → embed_dim _EMBED_MODELS = { @@ -60,6 +61,9 @@ _EMBED_MODELS = { "HUBERT_BASE_ML": 3072, # 768 * 4 "HUBERT_LARGE_ML": 4096, # 1024 * 4 "HUBERT_XLARGE_ML": 5120, # 1280 * 4 + # Transformers-based models + "AST": 768, + "AST_ML": 3072, # 768 * 4 } _DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE" @@ -90,6 +94,15 @@ def _get_w2v_model(model_name: str | None = None): _w2v_model = BEATs(cfg) _w2v_model.load_state_dict(checkpoint['model']) _w2v_model.to(_w2v_device) + elif load_name == "AST": + from transformers import ASTModel, ASTFeatureExtractor + _w2v_model = ASTModel.from_pretrained( + "MIT/ast-finetuned-audioset-10-10-0.4593" + ).to(_w2v_device) + global _ast_feature_extractor + _ast_feature_extractor = ASTFeatureExtractor.from_pretrained( + "MIT/ast-finetuned-audioset-10-10-0.4593" + ) else: import torchaudio bundle = getattr(torchaudio.pipelines, load_name) @@ -204,6 +217,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, import torch model, device = _get_w2v_model(model_name) is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" + is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML") ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL) # Auto-size batches based on available GPU memory batch_size = 16 @@ -234,6 +248,19 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, padding_mask = torch.zeros_like(waveforms, dtype=torch.bool) features, _ = model.extract_features(waveforms, padding_mask=padding_mask) batch_emb = features.mean(dim=1).cpu().numpy() + elif is_ast: + inputs = _ast_feature_extractor( + list(chunks), sampling_rate=sr, return_tensors="pt", + padding=True, + ) + input_values = inputs.input_values.to(device) + if ml_cfg is not None: + out = model(input_values, output_hidden_states=True) + selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]] + batch_emb = torch.cat(selected, dim=1).cpu().numpy() + else: + out = model(input_values) + batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy() elif ml_cfg is not None: all_layers, _ = model.extract_features(waveforms) selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] @@ -313,6 +340,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], embeddings_list: list[np.ndarray] = [] is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" + is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML") ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL) for batch_start in range(0, len(valid_times), batch_size): @@ -328,6 +356,19 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], padding_mask = torch.zeros_like(waveforms, dtype=torch.bool) features, _ = model.extract_features(waveforms, padding_mask=padding_mask) batch_emb = features.mean(dim=1).cpu().numpy() + elif is_ast: + inputs = _ast_feature_extractor( + list(chunks), sampling_rate=sr, return_tensors="pt", + padding=True, + ) + input_values = inputs.input_values.to(device) + if ml_cfg is not None: + out = model(input_values, output_hidden_states=True) + selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]] + batch_emb = torch.cat(selected, dim=1).cpu().numpy() + else: + out = model(input_values) + batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy() elif ml_cfg is not None: all_layers, _ = model.extract_features(waveforms) selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] diff --git a/tests/test_audio_scan.py b/tests/test_audio_scan.py index 38c7f00..baf34ad 100644 --- a/tests/test_audio_scan.py +++ b/tests/test_audio_scan.py @@ -47,6 +47,12 @@ def test_ml_config(): assert layers == [2, 5, 8, 11] +def test_embed_dim_ast(): + from core.audio_scan import _embed_dim + assert _embed_dim("AST") == 768 + assert _embed_dim("AST_ML") == 3072 + + def test_db_get_all_export_paths(): with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: path = f.name