From 082a2da438706f3505bf713564553107df484f55 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 10 Apr 2026 12:13:55 +0200 Subject: [PATCH] fix: restore dtype after float32 STFT in discriminator spectrogram torch.stft requires float32 input, but the .float() cast was not reversed before the spectrogram hit bfloat16 Conv2d weights. Save the original dtype and cast back after abs(). Co-Authored-By: Claude Opus 4.6 --- nodes/selva_bigvgan_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 965d3ad..804b012 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -125,8 +125,9 @@ class _DiscriminatorR(nn.Module): x = x.squeeze(1) # [B, T] pad = (win - hop) // 2 x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect") + orig_dtype = x.dtype x = torch.stft(x.float(), n, hop, win, window, center=False, return_complex=True) - x = x.abs().unsqueeze(1) # [B, 1, freq, time] + x = x.abs().to(orig_dtype).unsqueeze(1) # [B, 1, freq, time] return x def forward(self, x):