feat: multi-layer extraction for HuBERT/Wav2Vec2 models

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-19 13:53:55 +02:00
parent 4736f150b0
commit e3f133ef84
2 changed files with 74 additions and 7 deletions
+50 -5
View File
@@ -55,6 +55,11 @@ _EMBED_MODELS = {
"HUBERT_LARGE": 1024,
"HUBERT_XLARGE": 1280,
"BEATS": 768,
# Multi-layer variants (4 quartile layers concatenated)
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
"HUBERT_BASE_ML": 3072, # 768 * 4
"HUBERT_LARGE_ML": 4096, # 1024 * 4
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
}
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
@@ -70,11 +75,14 @@ def _get_w2v_model(model_name: str | None = None):
global _w2v_model, _w2v_device, _w2v_model_name
if model_name is None:
model_name = _DEFAULT_EMBED_MODEL
if _w2v_model is None or _w2v_model_name != model_name:
# Multi-layer variants use the same base model weights
ml = _ml_config(model_name)
load_name = ml[0] if ml else model_name
if _w2v_model is None or _w2v_model_name != load_name:
import torch
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
if model_name == "BEATS":
if load_name == "BEATS":
from .beats_model import BEATs, BEATsConfig
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
weights_only=False)
@@ -84,12 +92,12 @@ def _get_w2v_model(model_name: str | None = None):
_w2v_model.to(_w2v_device)
else:
import torchaudio
bundle = getattr(torchaudio.pipelines, model_name)
bundle = getattr(torchaudio.pipelines, load_name)
_w2v_model = bundle.get_model().to(_w2v_device)
_w2v_model.eval()
_w2v_model_name = model_name
_log(f"audio_scan: {model_name} loaded on {_w2v_device}")
_w2v_model_name = load_name
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
return _w2v_model, _w2v_device
@@ -100,6 +108,31 @@ def _embed_dim(model_name: str | None = None) -> int:
return _EMBED_MODELS.get(model_name, 768)
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
Returns None for single-layer models.
Layer indices are 0-based into the list returned by extract_features().
"""
if not model_name.endswith("_ML"):
return None
base = model_name[:-3] # strip "_ML"
if base not in _EMBED_MODELS:
return None
# Layer counts per model family
layer_counts = {
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
"AST": 12,
}
n = layer_counts.get(base)
if n is None:
return None
# Select 4 layers at quartile boundaries (0-indexed)
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
return base, indices
def _w2v_cache_path(video_path: str, hop: float, window: float,
model_name: str | None = None) -> str:
"""Return cache file path for a video's embeddings (includes model name)."""
@@ -171,6 +204,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
import torch
model, device = _get_w2v_model(model_name)
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
# Auto-size batches based on available GPU memory
batch_size = 16
if device == "cuda":
@@ -199,6 +233,11 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
batch_emb = features.mean(dim=1).cpu().numpy()
elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
else:
features, _ = model(waveforms)
batch_emb = features.mean(dim=1).cpu().numpy()
@@ -274,6 +313,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
embeddings_list: list[np.ndarray] = []
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
for batch_start in range(0, len(valid_times), batch_size):
batch_end = min(batch_start + batch_size, len(valid_times))
@@ -287,6 +327,11 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
batch_emb = features.mean(dim=1).cpu().numpy()
elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
else:
features, _ = model(waveforms)
batch_emb = features.mean(dim=1).cpu().numpy()
+22
View File
@@ -25,6 +25,28 @@ def test_default_model_path_contains_profile():
assert path.endswith(".joblib")
def test_embed_dim_multi_layer():
from core.audio_scan import _embed_dim
# Multi-layer models should report concatenated dimension
assert _embed_dim("HUBERT_XLARGE_ML") == 5120
assert _embed_dim("HUBERT_LARGE_ML") == 4096
assert _embed_dim("HUBERT_BASE_ML") == 3072
# Single-layer unchanged
assert _embed_dim("HUBERT_XLARGE") == 1280
def test_ml_config():
from core.audio_scan import _ml_config
assert _ml_config("HUBERT_XLARGE") is None
assert _ml_config("BEATS_ML") is None # BEATS has no ML variant
base, layers = _ml_config("HUBERT_XLARGE_ML")
assert base == "HUBERT_XLARGE"
assert layers == [11, 23, 35, 47]
base, layers = _ml_config("HUBERT_BASE_ML")
assert base == "HUBERT_BASE"
assert layers == [2, 5, 8, 11]
def test_db_get_all_export_paths():
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
path = f.name