diff --git a/nodes/selva_dataset_pipeline.py b/nodes/selva_dataset_pipeline.py index 85242b9..86772f2 100644 --- a/nodes/selva_dataset_pipeline.py +++ b/nodes/selva_dataset_pipeline.py @@ -205,7 +205,7 @@ def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool: stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True) mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs] - freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1) + freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1, device=mono.device) band_lo = (freqs >= 1000) & (freqs < 5000) band_hi = (freqs >= 15000) & (freqs < 20000) @@ -221,6 +221,8 @@ def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool: def _estimate_snr(wav: torch.Tensor) -> float: """Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS.""" mono = wav[0].mean(0) # [L] + if mono.shape[0] < 2048: + return 60.0 # clip too short to frame — assume clean frames = mono.unfold(0, 2048, 512) # [N, 2048] rms = frames.pow(2).mean(-1).sqrt() # [N] p95 = torch.quantile(rms, 0.95).item()