feat: multi-layer extraction for HuBERT/Wav2Vec2 models

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-19 13:53:55 +02:00
parent 4736f150b0
commit e3f133ef84
2 changed files with 74 additions and 7 deletions
+52 -7
View File
@@ -55,6 +55,11 @@ _EMBED_MODELS = {
"HUBERT_LARGE": 1024,
"HUBERT_XLARGE": 1280,
"BEATS": 768,
# Multi-layer variants (4 quartile layers concatenated)
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
"HUBERT_BASE_ML": 3072, # 768 * 4
"HUBERT_LARGE_ML": 4096, # 1024 * 4
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
}
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
@@ -70,11 +75,14 @@ def _get_w2v_model(model_name: str | None = None):
global _w2v_model, _w2v_device, _w2v_model_name
if model_name is None:
model_name = _DEFAULT_EMBED_MODEL
if _w2v_model is None or _w2v_model_name != model_name:
# Multi-layer variants use the same base model weights
ml = _ml_config(model_name)
load_name = ml[0] if ml else model_name
if _w2v_model is None or _w2v_model_name != load_name:
import torch
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
if model_name == "BEATS":
if load_name == "BEATS":
from .beats_model import BEATs, BEATsConfig
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
weights_only=False)
@@ -84,12 +92,12 @@ def _get_w2v_model(model_name: str | None = None):
_w2v_model.to(_w2v_device)
else:
import torchaudio
bundle = getattr(torchaudio.pipelines, model_name)
bundle = getattr(torchaudio.pipelines, load_name)
_w2v_model = bundle.get_model().to(_w2v_device)
_w2v_model.eval()
_w2v_model_name = model_name
_log(f"audio_scan: {model_name} loaded on {_w2v_device}")
_w2v_model_name = load_name
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
return _w2v_model, _w2v_device
@@ -100,6 +108,31 @@ def _embed_dim(model_name: str | None = None) -> int:
return _EMBED_MODELS.get(model_name, 768)
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
Returns None for single-layer models.
Layer indices are 0-based into the list returned by extract_features().
"""
if not model_name.endswith("_ML"):
return None
base = model_name[:-3] # strip "_ML"
if base not in _EMBED_MODELS:
return None
# Layer counts per model family
layer_counts = {
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
"AST": 12,
}
n = layer_counts.get(base)
if n is None:
return None
# Select 4 layers at quartile boundaries (0-indexed)
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
return base, indices
def _w2v_cache_path(video_path: str, hop: float, window: float,
model_name: str | None = None) -> str:
"""Return cache file path for a video's embeddings (includes model name)."""
@@ -171,6 +204,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
import torch
model, device = _get_w2v_model(model_name)
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
# Auto-size batches based on available GPU memory
batch_size = 16
if device == "cuda":
@@ -199,9 +233,14 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
batch_emb = features.mean(dim=1).cpu().numpy()
elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
else:
features, _ = model(waveforms)
batch_emb = features.mean(dim=1).cpu().numpy()
batch_emb = features.mean(dim=1).cpu().numpy()
embeddings.append(batch_emb)
result_ts = timestamps
@@ -274,6 +313,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
embeddings_list: list[np.ndarray] = []
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
for batch_start in range(0, len(valid_times), batch_size):
batch_end = min(batch_start + batch_size, len(valid_times))
@@ -287,9 +327,14 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
batch_emb = features.mean(dim=1).cpu().numpy()
elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
else:
features, _ = model(waveforms)
batch_emb = features.mean(dim=1).cpu().numpy()
batch_emb = features.mean(dim=1).cpu().numpy()
embeddings_list.append(batch_emb)
timestamps = np.array(timestamps_list)