diff --git a/core/audio_scan.py b/core/audio_scan.py index 7d7fda3..ad21caf 100644 --- a/core/audio_scan.py +++ b/core/audio_scan.py @@ -55,6 +55,11 @@ _EMBED_MODELS = { "HUBERT_LARGE": 1024, "HUBERT_XLARGE": 1280, "BEATS": 768, + # Multi-layer variants (4 quartile layers concatenated) + "WAV2VEC2_BASE_ML": 3072, # 768 * 4 + "HUBERT_BASE_ML": 3072, # 768 * 4 + "HUBERT_LARGE_ML": 4096, # 1024 * 4 + "HUBERT_XLARGE_ML": 5120, # 1280 * 4 } _DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE" @@ -70,11 +75,14 @@ def _get_w2v_model(model_name: str | None = None): global _w2v_model, _w2v_device, _w2v_model_name if model_name is None: model_name = _DEFAULT_EMBED_MODEL - if _w2v_model is None or _w2v_model_name != model_name: + # Multi-layer variants use the same base model weights + ml = _ml_config(model_name) + load_name = ml[0] if ml else model_name + if _w2v_model is None or _w2v_model_name != load_name: import torch _w2v_device = "cuda" if torch.cuda.is_available() else "cpu" - if model_name == "BEATS": + if load_name == "BEATS": from .beats_model import BEATs, BEATsConfig checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device, weights_only=False) @@ -84,12 +92,12 @@ def _get_w2v_model(model_name: str | None = None): _w2v_model.to(_w2v_device) else: import torchaudio - bundle = getattr(torchaudio.pipelines, model_name) + bundle = getattr(torchaudio.pipelines, load_name) _w2v_model = bundle.get_model().to(_w2v_device) _w2v_model.eval() - _w2v_model_name = model_name - _log(f"audio_scan: {model_name} loaded on {_w2v_device}") + _w2v_model_name = load_name + _log(f"audio_scan: {load_name} loaded on {_w2v_device}") return _w2v_model, _w2v_device @@ -100,6 +108,31 @@ def _embed_dim(model_name: str | None = None) -> int: return _EMBED_MODELS.get(model_name, 768) +def _ml_config(model_name: str) -> tuple[str, list[int]] | None: + """If model_name is a multi-layer variant, return (base_model, layer_indices). + + Returns None for single-layer models. + Layer indices are 0-based into the list returned by extract_features(). + """ + if not model_name.endswith("_ML"): + return None + base = model_name[:-3] # strip "_ML" + if base not in _EMBED_MODELS: + return None + # Layer counts per model family + layer_counts = { + "WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24, + "HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48, + "AST": 12, + } + n = layer_counts.get(base) + if n is None: + return None + # Select 4 layers at quartile boundaries (0-indexed) + indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1] + return base, indices + + def _w2v_cache_path(video_path: str, hop: float, window: float, model_name: str | None = None) -> str: """Return cache file path for a video's embeddings (includes model name).""" @@ -171,6 +204,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, import torch model, device = _get_w2v_model(model_name) is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" + ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL) # Auto-size batches based on available GPU memory batch_size = 16 if device == "cuda": @@ -199,9 +233,14 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, if is_beats: padding_mask = torch.zeros_like(waveforms, dtype=torch.bool) features, _ = model.extract_features(waveforms, padding_mask=padding_mask) + batch_emb = features.mean(dim=1).cpu().numpy() + elif ml_cfg is not None: + all_layers, _ = model.extract_features(waveforms) + selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] + batch_emb = torch.cat(selected, dim=1).cpu().numpy() else: features, _ = model(waveforms) - batch_emb = features.mean(dim=1).cpu().numpy() + batch_emb = features.mean(dim=1).cpu().numpy() embeddings.append(batch_emb) result_ts = timestamps @@ -274,6 +313,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], embeddings_list: list[np.ndarray] = [] is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" + ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL) for batch_start in range(0, len(valid_times), batch_size): batch_end = min(batch_start + batch_size, len(valid_times)) @@ -287,9 +327,14 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], if is_beats: padding_mask = torch.zeros_like(waveforms, dtype=torch.bool) features, _ = model.extract_features(waveforms, padding_mask=padding_mask) + batch_emb = features.mean(dim=1).cpu().numpy() + elif ml_cfg is not None: + all_layers, _ = model.extract_features(waveforms) + selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] + batch_emb = torch.cat(selected, dim=1).cpu().numpy() else: features, _ = model(waveforms) - batch_emb = features.mean(dim=1).cpu().numpy() + batch_emb = features.mean(dim=1).cpu().numpy() embeddings_list.append(batch_emb) timestamps = np.array(timestamps_list) diff --git a/tests/test_audio_scan.py b/tests/test_audio_scan.py index 1335527..38c7f00 100644 --- a/tests/test_audio_scan.py +++ b/tests/test_audio_scan.py @@ -25,6 +25,28 @@ def test_default_model_path_contains_profile(): assert path.endswith(".joblib") +def test_embed_dim_multi_layer(): + from core.audio_scan import _embed_dim + # Multi-layer models should report concatenated dimension + assert _embed_dim("HUBERT_XLARGE_ML") == 5120 + assert _embed_dim("HUBERT_LARGE_ML") == 4096 + assert _embed_dim("HUBERT_BASE_ML") == 3072 + # Single-layer unchanged + assert _embed_dim("HUBERT_XLARGE") == 1280 + + +def test_ml_config(): + from core.audio_scan import _ml_config + assert _ml_config("HUBERT_XLARGE") is None + assert _ml_config("BEATS_ML") is None # BEATS has no ML variant + base, layers = _ml_config("HUBERT_XLARGE_ML") + assert base == "HUBERT_XLARGE" + assert layers == [11, 23, 35, 47] + base, layers = _ml_config("HUBERT_BASE_ML") + assert base == "HUBERT_BASE" + assert layers == [2, 5, 8, 11] + + def test_db_get_all_export_paths(): with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: path = f.name