fix: avoid unnecessary GPU tensor allocation for AST/EAT models

Move waveforms creation inside the else branch so AST and EAT
models (which have their own preprocessing) don't waste GPU memory.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-19 14:53:05 +02:00
parent e7b791fbfa
commit c020c0dfec
+26 -24
View File
@@ -280,12 +280,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
start = i * hop_samples start = i * hop_samples
chunks.append(y[start:start + win_samples]) chunks.append(y[start:start + win_samples])
with torch.no_grad(): with torch.no_grad():
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device) if is_ast:
if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
batch_emb = features.mean(dim=1).cpu().numpy()
elif is_ast:
inputs = _ast_feature_extractor( inputs = _ast_feature_extractor(
list(chunks), sampling_rate=sr, return_tensors="pt", list(chunks), sampling_rate=sr, return_tensors="pt",
padding=True, padding=True,
@@ -302,13 +297,19 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
mel_input = _eat_preprocess(chunks, sr, device) mel_input = _eat_preprocess(chunks, sr, device)
features = model.extract_features(mel_input) features = model.extract_features(mel_input)
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy() batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
else: else:
features, _ = model(waveforms) waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
batch_emb = features.mean(dim=1).cpu().numpy() if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
batch_emb = features.mean(dim=1).cpu().numpy()
elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
else:
features, _ = model(waveforms)
batch_emb = features.mean(dim=1).cpu().numpy()
embeddings.append(batch_emb) embeddings.append(batch_emb)
result_ts = timestamps result_ts = timestamps
@@ -393,12 +394,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
chunks.append(y[start:start + win_samples]) chunks.append(y[start:start + win_samples])
timestamps_list.append(float(t)) timestamps_list.append(float(t))
with torch.no_grad(): with torch.no_grad():
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device) if is_ast:
if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
batch_emb = features.mean(dim=1).cpu().numpy()
elif is_ast:
inputs = _ast_feature_extractor( inputs = _ast_feature_extractor(
list(chunks), sampling_rate=sr, return_tensors="pt", list(chunks), sampling_rate=sr, return_tensors="pt",
padding=True, padding=True,
@@ -415,13 +411,19 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
mel_input = _eat_preprocess(chunks, sr, device) mel_input = _eat_preprocess(chunks, sr, device)
features = model.extract_features(mel_input) features = model.extract_features(mel_input)
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy() batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
else: else:
features, _ = model(waveforms) waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
batch_emb = features.mean(dim=1).cpu().numpy() if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
batch_emb = features.mean(dim=1).cpu().numpy()
elif ml_cfg is not None:
all_layers, _ = model.extract_features(waveforms)
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
else:
features, _ = model(waveforms)
batch_emb = features.mean(dim=1).cpu().numpy()
embeddings_list.append(batch_emb) embeddings_list.append(batch_emb)
timestamps = np.array(timestamps_list) timestamps = np.array(timestamps_list)