diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 965d3ad..804b012 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -125,8 +125,9 @@ class _DiscriminatorR(nn.Module): x = x.squeeze(1) # [B, T] pad = (win - hop) // 2 x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect") + orig_dtype = x.dtype x = torch.stft(x.float(), n, hop, win, window, center=False, return_complex=True) - x = x.abs().unsqueeze(1) # [B, 1, freq, time] + x = x.abs().to(orig_dtype).unsqueeze(1) # [B, 1, freq, time] return x def forward(self, x):