feat: multi-layer extraction for HuBERT/Wav2Vec2 models
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+52
-7
@@ -55,6 +55,11 @@ _EMBED_MODELS = {
|
||||
"HUBERT_LARGE": 1024,
|
||||
"HUBERT_XLARGE": 1280,
|
||||
"BEATS": 768,
|
||||
# Multi-layer variants (4 quartile layers concatenated)
|
||||
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
|
||||
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||
}
|
||||
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
||||
|
||||
@@ -70,11 +75,14 @@ def _get_w2v_model(model_name: str | None = None):
|
||||
global _w2v_model, _w2v_device, _w2v_model_name
|
||||
if model_name is None:
|
||||
model_name = _DEFAULT_EMBED_MODEL
|
||||
if _w2v_model is None or _w2v_model_name != model_name:
|
||||
# Multi-layer variants use the same base model weights
|
||||
ml = _ml_config(model_name)
|
||||
load_name = ml[0] if ml else model_name
|
||||
if _w2v_model is None or _w2v_model_name != load_name:
|
||||
import torch
|
||||
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
if model_name == "BEATS":
|
||||
if load_name == "BEATS":
|
||||
from .beats_model import BEATs, BEATsConfig
|
||||
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
|
||||
weights_only=False)
|
||||
@@ -84,12 +92,12 @@ def _get_w2v_model(model_name: str | None = None):
|
||||
_w2v_model.to(_w2v_device)
|
||||
else:
|
||||
import torchaudio
|
||||
bundle = getattr(torchaudio.pipelines, model_name)
|
||||
bundle = getattr(torchaudio.pipelines, load_name)
|
||||
_w2v_model = bundle.get_model().to(_w2v_device)
|
||||
|
||||
_w2v_model.eval()
|
||||
_w2v_model_name = model_name
|
||||
_log(f"audio_scan: {model_name} loaded on {_w2v_device}")
|
||||
_w2v_model_name = load_name
|
||||
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
|
||||
return _w2v_model, _w2v_device
|
||||
|
||||
|
||||
@@ -100,6 +108,31 @@ def _embed_dim(model_name: str | None = None) -> int:
|
||||
return _EMBED_MODELS.get(model_name, 768)
|
||||
|
||||
|
||||
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
|
||||
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
|
||||
|
||||
Returns None for single-layer models.
|
||||
Layer indices are 0-based into the list returned by extract_features().
|
||||
"""
|
||||
if not model_name.endswith("_ML"):
|
||||
return None
|
||||
base = model_name[:-3] # strip "_ML"
|
||||
if base not in _EMBED_MODELS:
|
||||
return None
|
||||
# Layer counts per model family
|
||||
layer_counts = {
|
||||
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
|
||||
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
|
||||
"AST": 12,
|
||||
}
|
||||
n = layer_counts.get(base)
|
||||
if n is None:
|
||||
return None
|
||||
# Select 4 layers at quartile boundaries (0-indexed)
|
||||
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
|
||||
return base, indices
|
||||
|
||||
|
||||
def _w2v_cache_path(video_path: str, hop: float, window: float,
|
||||
model_name: str | None = None) -> str:
|
||||
"""Return cache file path for a video's embeddings (includes model name)."""
|
||||
@@ -171,6 +204,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||
import torch
|
||||
model, device = _get_w2v_model(model_name)
|
||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
# Auto-size batches based on available GPU memory
|
||||
batch_size = 16
|
||||
if device == "cuda":
|
||||
@@ -199,9 +233,14 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
embeddings.append(batch_emb)
|
||||
|
||||
result_ts = timestamps
|
||||
@@ -274,6 +313,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||
embeddings_list: list[np.ndarray] = []
|
||||
|
||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
|
||||
for batch_start in range(0, len(valid_times), batch_size):
|
||||
batch_end = min(batch_start + batch_size, len(valid_times))
|
||||
@@ -287,9 +327,14 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
embeddings_list.append(batch_emb)
|
||||
|
||||
timestamps = np.array(timestamps_list)
|
||||
|
||||
@@ -25,6 +25,28 @@ def test_default_model_path_contains_profile():
|
||||
assert path.endswith(".joblib")
|
||||
|
||||
|
||||
def test_embed_dim_multi_layer():
|
||||
from core.audio_scan import _embed_dim
|
||||
# Multi-layer models should report concatenated dimension
|
||||
assert _embed_dim("HUBERT_XLARGE_ML") == 5120
|
||||
assert _embed_dim("HUBERT_LARGE_ML") == 4096
|
||||
assert _embed_dim("HUBERT_BASE_ML") == 3072
|
||||
# Single-layer unchanged
|
||||
assert _embed_dim("HUBERT_XLARGE") == 1280
|
||||
|
||||
|
||||
def test_ml_config():
|
||||
from core.audio_scan import _ml_config
|
||||
assert _ml_config("HUBERT_XLARGE") is None
|
||||
assert _ml_config("BEATS_ML") is None # BEATS has no ML variant
|
||||
base, layers = _ml_config("HUBERT_XLARGE_ML")
|
||||
assert base == "HUBERT_XLARGE"
|
||||
assert layers == [11, 23, 35, 47]
|
||||
base, layers = _ml_config("HUBERT_BASE_ML")
|
||||
assert base == "HUBERT_BASE"
|
||||
assert layers == [2, 5, 8, 11]
|
||||
|
||||
|
||||
def test_db_get_all_export_paths():
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
|
||||
Reference in New Issue
Block a user