From f50afa9796b7ce3a03e2f3fc3412ff7f24ab56bd Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 14:28:36 +0200 Subject: [PATCH] fix: guard _estimate_snr against short clips, fix freqs device in _check_hf_shelf MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug 1: mono.unfold(0, 2048, 512) returns an empty tensor for clips shorter than 2048 samples (~46ms). torch.quantile on an empty tensor crashes with "quantile() input tensor must be non-empty". Guard: return 60.0 (assume clean) for clips too short to frame — the pipeline has no minimum-length filter so any short file in the dataset folder would crash the Inspector. Bug 2: torch.linspace(...) in _check_hf_shelf created a CPU tensor, making band_lo/band_hi CPU boolean masks. Indexing a GPU mag_sq tensor with CPU masks crashes. Pass device=mono.device so freqs lands on the same device as the audio. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_dataset_pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nodes/selva_dataset_pipeline.py b/nodes/selva_dataset_pipeline.py index 85242b9..86772f2 100644 --- a/nodes/selva_dataset_pipeline.py +++ b/nodes/selva_dataset_pipeline.py @@ -205,7 +205,7 @@ def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool: stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True) mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs] - freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1) + freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1, device=mono.device) band_lo = (freqs >= 1000) & (freqs < 5000) band_hi = (freqs >= 15000) & (freqs < 20000) @@ -221,6 +221,8 @@ def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool: def _estimate_snr(wav: torch.Tensor) -> float: """Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS.""" mono = wav[0].mean(0) # [L] + if mono.shape[0] < 2048: + return 60.0 # clip too short to frame — assume clean frames = mono.unfold(0, 2048, 512) # [N, 2048] rms = frames.pow(2).mean(-1).sqrt() # [N] p95 = torch.quantile(rms, 0.95).item()