feat: add EAT (Efficient Audio Transformer) embedding model

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-19 14:00:09 +02:00
parent 5b25e85e98
commit 8fb8581816
2 changed files with 51 additions and 0 deletions
+46
View File
@@ -64,6 +64,7 @@ _EMBED_MODELS = {
# Transformers-based models # Transformers-based models
"AST": 768, "AST": 768,
"AST_ML": 3072, # 768 * 4 "AST_ML": 3072, # 768 * 4
"EAT": 768,
} }
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE" _DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
@@ -103,6 +104,12 @@ def _get_w2v_model(model_name: str | None = None):
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained( _ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
"MIT/ast-finetuned-audioset-10-10-0.4593" "MIT/ast-finetuned-audioset-10-10-0.4593"
) )
elif load_name == "EAT":
from transformers import AutoModel
_w2v_model = AutoModel.from_pretrained(
"worstchan/EAT-base_epoch30_finetune_AS2M",
trust_remote_code=True,
).to(_w2v_device)
else: else:
import torchaudio import torchaudio
bundle = getattr(torchaudio.pipelines, load_name) bundle = getattr(torchaudio.pipelines, load_name)
@@ -114,6 +121,35 @@ def _get_w2v_model(model_name: str | None = None):
return _w2v_model, _w2v_device return _w2v_model, _w2v_device
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
"""Convert raw audio chunks to EAT mel spectrogram input.
Returns tensor of shape [B, 1, T, 128].
8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
"""
import torch
import torchaudio.compliance.kaldi as kaldi
TARGET_LEN = 1024
MEAN, STD = -4.268, 4.569
mels = []
for chunk in chunks:
wav = torch.from_numpy(chunk).unsqueeze(0).float()
fbank = kaldi.fbank(
wav, htk_compat=True, sample_frequency=sr, use_energy=False,
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10,
)
# Pad or truncate to TARGET_LEN
if fbank.shape[0] < TARGET_LEN:
fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0]))
else:
fbank = fbank[:TARGET_LEN]
fbank = (fbank - MEAN) / (STD * 2)
mels.append(fbank)
return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128]
def _embed_dim(model_name: str | None = None) -> int: def _embed_dim(model_name: str | None = None) -> int:
"""Return embedding dimension for a model name.""" """Return embedding dimension for a model name."""
if model_name is None: if model_name is None:
@@ -218,6 +254,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
model, device = _get_w2v_model(model_name) model, device = _get_w2v_model(model_name)
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML") is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL) ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
# Auto-size batches based on available GPU memory # Auto-size batches based on available GPU memory
batch_size = 16 batch_size = 16
@@ -261,6 +298,10 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
else: else:
out = model(input_values) out = model(input_values)
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy() batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
elif is_eat:
mel_input = _eat_preprocess(chunks, sr, device)
features = model.extract_features(mel_input)
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
elif ml_cfg is not None: elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms) all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
@@ -341,6 +382,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML") is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL) ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
for batch_start in range(0, len(valid_times), batch_size): for batch_start in range(0, len(valid_times), batch_size):
@@ -369,6 +411,10 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
else: else:
out = model(input_values) out = model(input_values)
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy() batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
elif is_eat:
mel_input = _eat_preprocess(chunks, sr, device)
features = model.extract_features(mel_input)
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
elif ml_cfg is not None: elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms) all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
+5
View File
@@ -53,6 +53,11 @@ def test_embed_dim_ast():
assert _embed_dim("AST_ML") == 3072 assert _embed_dim("AST_ML") == 3072
def test_embed_dim_eat():
from core.audio_scan import _embed_dim
assert _embed_dim("EAT") == 768
def test_db_get_all_export_paths(): def test_db_get_all_export_paths():
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
path = f.name path = f.name