feat: add EAT (Efficient Audio Transformer) embedding model
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -64,6 +64,7 @@ _EMBED_MODELS = {
|
|||||||
# Transformers-based models
|
# Transformers-based models
|
||||||
"AST": 768,
|
"AST": 768,
|
||||||
"AST_ML": 3072, # 768 * 4
|
"AST_ML": 3072, # 768 * 4
|
||||||
|
"EAT": 768,
|
||||||
}
|
}
|
||||||
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
||||||
|
|
||||||
@@ -103,6 +104,12 @@ def _get_w2v_model(model_name: str | None = None):
|
|||||||
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||||
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
)
|
)
|
||||||
|
elif load_name == "EAT":
|
||||||
|
from transformers import AutoModel
|
||||||
|
_w2v_model = AutoModel.from_pretrained(
|
||||||
|
"worstchan/EAT-base_epoch30_finetune_AS2M",
|
||||||
|
trust_remote_code=True,
|
||||||
|
).to(_w2v_device)
|
||||||
else:
|
else:
|
||||||
import torchaudio
|
import torchaudio
|
||||||
bundle = getattr(torchaudio.pipelines, load_name)
|
bundle = getattr(torchaudio.pipelines, load_name)
|
||||||
@@ -114,6 +121,35 @@ def _get_w2v_model(model_name: str | None = None):
|
|||||||
return _w2v_model, _w2v_device
|
return _w2v_model, _w2v_device
|
||||||
|
|
||||||
|
|
||||||
|
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
    """Turn raw audio chunks into the batched mel-spectrogram input for EAT.

    Every chunk is converted to a 128-bin Kaldi filterbank, forced to exactly
    1024 frames (zero-padded at the end or truncated), normalized with fixed
    constants, and the results are stacked into one tensor on ``device``.

    Args:
        chunks: Audio chunks as numpy arrays.
            NOTE(review): presumably each chunk is 1-D mono audio, since a
            single leading channel axis is added before fbank -- confirm.
        sr: Sample rate forwarded to the filterbank as ``sample_frequency``.
        device: Target device for the returned tensor.

    Returns:
        Tensor of shape [B, 1, T, 128] with T == 1024.
        8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
    """
    import torch
    import torchaudio.compliance.kaldi as kaldi

    n_frames_target = 1024
    # Fixed normalization constants.
    # NOTE(review): presumably the statistics expected by the EAT
    # checkpoint -- confirm against the model's reference preprocessing.
    norm_mean, norm_std = -4.268, 4.569

    def _chunk_to_mel(samples):
        # Add a leading axis and cast to float32 before computing the fbank.
        waveform = torch.from_numpy(samples).unsqueeze(0).float()
        spec = kaldi.fbank(
            waveform,
            htk_compat=True,
            sample_frequency=sr,
            use_energy=False,
            window_type='hanning',
            num_mel_bins=128,
            dither=0.0,
            frame_shift=10,
        )
        # Fix the time axis to n_frames_target; padding is applied before
        # normalization, so padded frames hold the normalized value of 0.
        missing = n_frames_target - spec.shape[0]
        if missing > 0:
            spec = torch.nn.functional.pad(spec, (0, 0, 0, missing))
        else:
            spec = spec[:n_frames_target]
        return (spec - norm_mean) / (norm_std * 2)

    batch = torch.stack([_chunk_to_mel(chunk) for chunk in chunks])
    return batch.unsqueeze(1).to(device)  # [B, 1, T, 128]
|
||||||
|
|
||||||
|
|
||||||
def _embed_dim(model_name: str | None = None) -> int:
|
def _embed_dim(model_name: str | None = None) -> int:
|
||||||
"""Return embedding dimension for a model name."""
|
"""Return embedding dimension for a model name."""
|
||||||
if model_name is None:
|
if model_name is None:
|
||||||
@@ -218,6 +254,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
model, device = _get_w2v_model(model_name)
|
model, device = _get_w2v_model(model_name)
|
||||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
|
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
|
||||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
# Auto-size batches based on available GPU memory
|
# Auto-size batches based on available GPU memory
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
@@ -261,6 +298,10 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
else:
|
else:
|
||||||
out = model(input_values)
|
out = model(input_values)
|
||||||
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
elif is_eat:
|
||||||
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
|
features = model.extract_features(mel_input)
|
||||||
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
elif ml_cfg is not None:
|
elif ml_cfg is not None:
|
||||||
all_layers, _ = model.extract_features(waveforms)
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
@@ -341,6 +382,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
|
|
||||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
|
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
|
||||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
|
|
||||||
for batch_start in range(0, len(valid_times), batch_size):
|
for batch_start in range(0, len(valid_times), batch_size):
|
||||||
@@ -369,6 +411,10 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
else:
|
else:
|
||||||
out = model(input_values)
|
out = model(input_values)
|
||||||
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
elif is_eat:
|
||||||
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
|
features = model.extract_features(mel_input)
|
||||||
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
elif ml_cfg is not None:
|
elif ml_cfg is not None:
|
||||||
all_layers, _ = model.extract_features(waveforms)
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
|||||||
@@ -53,6 +53,11 @@ def test_embed_dim_ast():
|
|||||||
assert _embed_dim("AST_ML") == 3072
|
assert _embed_dim("AST_ML") == 3072
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_eat():
    """Check that ``_embed_dim`` reports 768 for the "EAT" model name."""
    from core.audio_scan import _embed_dim

    expected_dim = 768
    assert _embed_dim("EAT") == expected_dim
|
||||||
|
|
||||||
|
|
||||||
def test_db_get_all_export_paths():
|
def test_db_get_all_export_paths():
|
||||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
path = f.name
|
path = f.name
|
||||||
|
|||||||
Reference in New Issue
Block a user