diff --git a/core/audio_scan.py b/core/audio_scan.py index 2267abe..d41dfbe 100644 --- a/core/audio_scan.py +++ b/core/audio_scan.py @@ -280,12 +280,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, start = i * hop_samples chunks.append(y[start:start + win_samples]) with torch.no_grad(): - waveforms = torch.from_numpy(np.stack(chunks)).float().to(device) - if is_beats: - padding_mask = torch.zeros_like(waveforms, dtype=torch.bool) - features, _ = model.extract_features(waveforms, padding_mask=padding_mask) - batch_emb = features.mean(dim=1).cpu().numpy() - elif is_ast: + if is_ast: inputs = _ast_feature_extractor( list(chunks), sampling_rate=sr, return_tensors="pt", padding=True, @@ -302,13 +297,19 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, mel_input = _eat_preprocess(chunks, sr, device) features = model.extract_features(mel_input) batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy() - elif ml_cfg is not None: - all_layers, _ = model.extract_features(waveforms) - selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] - batch_emb = torch.cat(selected, dim=1).cpu().numpy() else: - features, _ = model(waveforms) - batch_emb = features.mean(dim=1).cpu().numpy() + waveforms = torch.from_numpy(np.stack(chunks)).float().to(device) + if is_beats: + padding_mask = torch.zeros_like(waveforms, dtype=torch.bool) + features, _ = model.extract_features(waveforms, padding_mask=padding_mask) + batch_emb = features.mean(dim=1).cpu().numpy() + elif ml_cfg is not None: + all_layers, _ = model.extract_features(waveforms) + selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] + batch_emb = torch.cat(selected, dim=1).cpu().numpy() + else: + features, _ = model(waveforms) + batch_emb = features.mean(dim=1).cpu().numpy() embeddings.append(batch_emb) result_ts = timestamps @@ -393,12 +394,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], chunks.append(y[start:start + win_samples]) timestamps_list.append(float(t)) with torch.no_grad(): - waveforms = torch.from_numpy(np.stack(chunks)).float().to(device) - if is_beats: - padding_mask = torch.zeros_like(waveforms, dtype=torch.bool) - features, _ = model.extract_features(waveforms, padding_mask=padding_mask) - batch_emb = features.mean(dim=1).cpu().numpy() - elif is_ast: + if is_ast: inputs = _ast_feature_extractor( list(chunks), sampling_rate=sr, return_tensors="pt", padding=True, @@ -415,13 +411,19 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], mel_input = _eat_preprocess(chunks, sr, device) features = model.extract_features(mel_input) batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy() - elif ml_cfg is not None: - all_layers, _ = model.extract_features(waveforms) - selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] - batch_emb = torch.cat(selected, dim=1).cpu().numpy() else: - features, _ = model(waveforms) - batch_emb = features.mean(dim=1).cpu().numpy() + waveforms = torch.from_numpy(np.stack(chunks)).float().to(device) + if is_beats: + padding_mask = torch.zeros_like(waveforms, dtype=torch.bool) + features, _ = model.extract_features(waveforms, padding_mask=padding_mask) + batch_emb = features.mean(dim=1).cpu().numpy() + elif ml_cfg is not None: + all_layers, _ = model.extract_features(waveforms) + selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]] + batch_emb = torch.cat(selected, dim=1).cpu().numpy() + else: + features, _ = model(waveforms) + batch_emb = features.mean(dim=1).cpu().numpy() embeddings_list.append(batch_emb) timestamps = np.array(timestamps_list)