fix: restore dtype after float32 STFT in discriminator spectrogram

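torch.stft requires float32 input, so the signal is cast with .float()
before the call. That cast was never reversed, so the resulting float32
spectrogram was passed to Conv2d layers whose weights are bfloat16,
mismatching their dtype. Save the original dtype before the STFT and
cast the magnitude back to it after abs().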

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 12:13:55 +02:00
parent c28e090196
commit 082a2da438
+2 -1
View File
@@ -125,8 +125,9 @@ class _DiscriminatorR(nn.Module):
x = x.squeeze(1) # [B, T]
pad = (win - hop) // 2
x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect")
orig_dtype = x.dtype
x = torch.stft(x.float(), n, hop, win, window, center=False, return_complex=True)
x = x.abs().unsqueeze(1) # [B, 1, freq, time]
x = x.abs().to(orig_dtype).unsqueeze(1) # [B, 1, freq, time]
return x
def forward(self, x):