fix: avoid unnecessary GPU tensor allocation for AST/EAT models
Move waveforms creation inside the else branch so AST and EAT models (which have their own preprocessing) don't waste GPU memory. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+14
-12
@@ -280,12 +280,7 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
start = i * hop_samples
|
start = i * hop_samples
|
||||||
chunks.append(y[start:start + win_samples])
|
chunks.append(y[start:start + win_samples])
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
if is_ast:
|
||||||
if is_beats:
|
|
||||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
|
||||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
|
||||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
|
||||||
elif is_ast:
|
|
||||||
inputs = _ast_feature_extractor(
|
inputs = _ast_feature_extractor(
|
||||||
list(chunks), sampling_rate=sr, return_tensors="pt",
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
@@ -302,6 +297,12 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
mel_input = _eat_preprocess(chunks, sr, device)
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
features = model.extract_features(mel_input)
|
features = model.extract_features(mel_input)
|
||||||
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
elif ml_cfg is not None:
|
elif ml_cfg is not None:
|
||||||
all_layers, _ = model.extract_features(waveforms)
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
@@ -393,12 +394,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
chunks.append(y[start:start + win_samples])
|
chunks.append(y[start:start + win_samples])
|
||||||
timestamps_list.append(float(t))
|
timestamps_list.append(float(t))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
if is_ast:
|
||||||
if is_beats:
|
|
||||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
|
||||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
|
||||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
|
||||||
elif is_ast:
|
|
||||||
inputs = _ast_feature_extractor(
|
inputs = _ast_feature_extractor(
|
||||||
list(chunks), sampling_rate=sr, return_tensors="pt",
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
@@ -415,6 +411,12 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
mel_input = _eat_preprocess(chunks, sr, device)
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
features = model.extract_features(mel_input)
|
features = model.extract_features(mel_input)
|
||||||
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
elif ml_cfg is not None:
|
elif ml_cfg is not None:
|
||||||
all_layers, _ = model.extract_features(waveforms)
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
|||||||
Reference in New Issue
Block a user