diff --git a/nodes/selva_dataset_pipeline.py b/nodes/selva_dataset_pipeline.py index 398f8df..5bc354d 100644 --- a/nodes/selva_dataset_pipeline.py +++ b/nodes/selva_dataset_pipeline.py @@ -200,8 +200,8 @@ def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool: n_fft = 2048 hop = 512 - window = torch.hann_window(n_fft) mono = wav[0].mean(0) # [L] + window = torch.hann_window(n_fft, device=mono.device) stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True) mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]