fix: pass device to hann_window in _check_hf_shelf to avoid GPU mismatch

This commit is contained in:
2026-04-09 14:22:13 +02:00
parent 0731addea9
commit 2d06cb2f52
+1 -1
View File
@@ -200,8 +200,8 @@ def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
n_fft = 2048 n_fft = 2048
hop = 512 hop = 512
window = torch.hann_window(n_fft)
mono = wav[0].mean(0) # [L] mono = wav[0].mean(0) # [L]
window = torch.hann_window(n_fft, device=mono.device)
stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True) stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs] mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]