feat: add EAT (Efficient Audio Transformer) embedding model
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -64,6 +64,7 @@ _EMBED_MODELS = {
|
||||
# Transformers-based models
|
||||
"AST": 768,
|
||||
"AST_ML": 3072, # 768 * 4
|
||||
"EAT": 768,
|
||||
}
|
||||
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
||||
|
||||
@@ -103,6 +104,12 @@ def _get_w2v_model(model_name: str | None = None):
|
||||
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||
)
|
||||
elif load_name == "EAT":
|
||||
from transformers import AutoModel
|
||||
_w2v_model = AutoModel.from_pretrained(
|
||||
"worstchan/EAT-base_epoch30_finetune_AS2M",
|
||||
trust_remote_code=True,
|
||||
).to(_w2v_device)
|
||||
else:
|
||||
import torchaudio
|
||||
bundle = getattr(torchaudio.pipelines, load_name)
|
||||
@@ -114,6 +121,35 @@ def _get_w2v_model(model_name: str | None = None):
|
||||
return _w2v_model, _w2v_device
|
||||
|
||||
|
||||
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
|
||||
"""Convert raw audio chunks to EAT mel spectrogram input.
|
||||
|
||||
Returns tensor of shape [B, 1, T, 128].
|
||||
8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
|
||||
"""
|
||||
import torch
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
TARGET_LEN = 1024
|
||||
MEAN, STD = -4.268, 4.569
|
||||
|
||||
mels = []
|
||||
for chunk in chunks:
|
||||
wav = torch.from_numpy(chunk).unsqueeze(0).float()
|
||||
fbank = kaldi.fbank(
|
||||
wav, htk_compat=True, sample_frequency=sr, use_energy=False,
|
||||
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10,
|
||||
)
|
||||
# Pad or truncate to TARGET_LEN
|
||||
if fbank.shape[0] < TARGET_LEN:
|
||||
fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0]))
|
||||
else:
|
||||
fbank = fbank[:TARGET_LEN]
|
||||
fbank = (fbank - MEAN) / (STD * 2)
|
||||
mels.append(fbank)
|
||||
return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128]
|
||||
|
||||
|
||||
def _embed_dim(model_name: str | None = None) -> int:
|
||||
"""Return embedding dimension for a model name."""
|
||||
if model_name is None:
|
||||
@@ -218,6 +254,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||
model, device = _get_w2v_model(model_name)
|
||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
# Auto-size batches based on available GPU memory
|
||||
batch_size = 16
|
||||
@@ -261,6 +298,10 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||
else:
|
||||
out = model(input_values)
|
||||
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||
elif is_eat:
|
||||
mel_input = _eat_preprocess(chunks, sr, device)
|
||||
features = model.extract_features(mel_input)
|
||||
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
@@ -341,6 +382,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||
|
||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
|
||||
for batch_start in range(0, len(valid_times), batch_size):
|
||||
@@ -369,6 +411,10 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||
else:
|
||||
out = model(input_values)
|
||||
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||
elif is_eat:
|
||||
mel_input = _eat_preprocess(chunks, sr, device)
|
||||
features = model.extract_features(mel_input)
|
||||
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
|
||||
@@ -53,6 +53,11 @@ def test_embed_dim_ast():
|
||||
assert _embed_dim("AST_ML") == 3072
|
||||
|
||||
|
||||
def test_embed_dim_eat():
|
||||
from core.audio_scan import _embed_dim
|
||||
assert _embed_dim("EAT") == 768
|
||||
|
||||
|
||||
def test_db_get_all_export_paths():
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
|
||||
Reference in New Issue
Block a user