diff --git a/core/audio_scan.py b/core/audio_scan.py index 9dbb59c..833e7b2 100644 --- a/core/audio_scan.py +++ b/core/audio_scan.py @@ -64,6 +64,7 @@ _EMBED_MODELS = { # Transformers-based models "AST": 768, "AST_ML": 3072, # 768 * 4 + "EAT": 768, } _DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE" @@ -103,6 +104,12 @@ def _get_w2v_model(model_name: str | None = None): _ast_feature_extractor = ASTFeatureExtractor.from_pretrained( "MIT/ast-finetuned-audioset-10-10-0.4593" ) + elif load_name == "EAT": + from transformers import AutoModel + _w2v_model = AutoModel.from_pretrained( + "worstchan/EAT-base_epoch30_finetune_AS2M", + trust_remote_code=True, + ).to(_w2v_device) else: import torchaudio bundle = getattr(torchaudio.pipelines, load_name) @@ -114,6 +121,35 @@ def _get_w2v_model(model_name: str | None = None): return _w2v_model, _w2v_device +def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str): + """Convert raw audio chunks to EAT mel spectrogram input. + + Returns tensor of shape [B, 1, T, 128]. + 8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024. + """ + import torch + import torchaudio.compliance.kaldi as kaldi + + TARGET_LEN = 1024 + MEAN, STD = -4.268, 4.569 + + mels = [] + for chunk in chunks: + wav = torch.from_numpy(chunk).unsqueeze(0).float() + fbank = kaldi.fbank( + wav, htk_compat=True, sample_frequency=sr, use_energy=False, + window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10, + ) + # Pad or truncate to TARGET_LEN + if fbank.shape[0] < TARGET_LEN: + fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0])) + else: + fbank = fbank[:TARGET_LEN] + fbank = (fbank - MEAN) / (STD * 2) + mels.append(fbank) + return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128] + + def _embed_dim(model_name: str | None = None) -> int: """Return embedding dimension for a model name.""" if model_name is None: @@ -218,6 +254,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, model, device = _get_w2v_model(model_name) is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML") + is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT" ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL) # Auto-size batches based on available GPU memory batch_size = 16 @@ -261,6 +298,10 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, else: out = model(input_values) batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy() + elif is_eat: + mel_input = _eat_preprocess(chunks, sr, device) + features = model.extract_features(mel_input) + batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy() elif ml_cfg is not None: all_layers, _ = model.extract_features(waveforms) selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] @@ -341,6 +382,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML") + is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT" ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL) for batch_start in range(0, len(valid_times), batch_size): @@ -369,6 +411,10 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], else: out = model(input_values) batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy() + elif is_eat: + mel_input = _eat_preprocess(chunks, sr, device) + features = model.extract_features(mel_input) + batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy() elif ml_cfg is not None: all_layers, _ = model.extract_features(waveforms) selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] diff --git a/tests/test_audio_scan.py b/tests/test_audio_scan.py index baf34ad..1de0d5f 100644 --- a/tests/test_audio_scan.py +++ b/tests/test_audio_scan.py @@ -53,6 +53,11 @@ def test_embed_dim_ast(): assert _embed_dim("AST_ML") == 3072 +def test_embed_dim_eat(): + from core.audio_scan import _embed_dim + assert _embed_dim("EAT") == 768 + + def test_db_get_all_export_paths(): with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: path = f.name