fix: avoid unnecessary GPU tensor allocation for AST/EAT models
Move waveforms creation inside the else branch so AST and EAT models (which have their own preprocessing) don't waste GPU memory. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+26
-24
@@ -280,12 +280,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||
start = i * hop_samples
|
||||
chunks.append(y[start:start + win_samples])
|
||||
with torch.no_grad():
|
||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif is_ast:
|
||||
if is_ast:
|
||||
inputs = _ast_feature_extractor(
|
||||
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||
padding=True,
|
||||
@@ -302,13 +297,19 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||
mel_input = _eat_preprocess(chunks, sr, device)
|
||||
features = model.extract_features(mel_input)
|
||||
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
embeddings.append(batch_emb)
|
||||
|
||||
result_ts = timestamps
|
||||
@@ -393,12 +394,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||
chunks.append(y[start:start + win_samples])
|
||||
timestamps_list.append(float(t))
|
||||
with torch.no_grad():
|
||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif is_ast:
|
||||
if is_ast:
|
||||
inputs = _ast_feature_extractor(
|
||||
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||
padding=True,
|
||||
@@ -415,13 +411,19 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||
mel_input = _eat_preprocess(chunks, sr, device)
|
||||
features = model.extract_features(mel_input)
|
||||
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
embeddings_list.append(batch_emb)
|
||||
|
||||
timestamps = np.array(timestamps_list)
|
||||
|
||||
Reference in New Issue
Block a user