fix: guard _estimate_snr against short clips, fix freqs device in _check_hf_shelf

Bug 1: mono.unfold(0, 2048, 512) returns an empty tensor for clips shorter
than 2048 samples (~46ms). torch.quantile on an empty tensor crashes with
"quantile() input tensor must be non-empty". Guard: return 60.0 (assume
clean) for clips too short to frame — the pipeline has no minimum-length
filter so any short file in the dataset folder would crash the Inspector.

Bug 2: torch.linspace(...) in _check_hf_shelf created a CPU tensor, making
band_lo/band_hi CPU boolean masks. Indexing a GPU mag_sq tensor with CPU
masks crashes. Pass device=mono.device so freqs lands on the same device
as the audio.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 14:28:36 +02:00
parent 8a85819f97
commit f50afa9796
+3 -1
View File
@@ -205,7 +205,7 @@ def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True) stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs] mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]
freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1) freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1, device=mono.device)
band_lo = (freqs >= 1000) & (freqs < 5000) band_lo = (freqs >= 1000) & (freqs < 5000)
band_hi = (freqs >= 15000) & (freqs < 20000) band_hi = (freqs >= 15000) & (freqs < 20000)
@@ -221,6 +221,8 @@ def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
def _estimate_snr(wav: torch.Tensor) -> float: def _estimate_snr(wav: torch.Tensor) -> float:
"""Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS.""" """Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
mono = wav[0].mean(0) # [L] mono = wav[0].mean(0) # [L]
if mono.shape[0] < 2048:
return 60.0 # clip too short to frame — assume clean
frames = mono.unfold(0, 2048, 512) # [N, 2048] frames = mono.unfold(0, 2048, 512) # [N, 2048]
rms = frames.pow(2).mean(-1).sqrt() # [N] rms = frames.pow(2).mean(-1).sqrt() # [N]
p95 = torch.quantile(rms, 0.95).item() p95 = torch.quantile(rms, 0.95).item()