fix: restore dtype after float32 STFT in discriminator spectrogram
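torch.stft requires float32 input, so the waveform is cast with .float() before the transform. That cast was never reversed, so the resulting float32 spectrogram was passed to Conv2d layers whose weights are bfloat16. Save the input's original dtype before the STFT and cast the magnitude back to it after abs().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>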
This commit is contained in:
@@ -125,8 +125,9 @@ class _DiscriminatorR(nn.Module):
|
||||
x = x.squeeze(1) # [B, T]
|
||||
pad = (win - hop) // 2
|
||||
x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect")
|
||||
orig_dtype = x.dtype
|
||||
x = torch.stft(x.float(), n, hop, win, window, center=False, return_complex=True)
|
||||
x = x.abs().unsqueeze(1) # [B, 1, freq, time]
|
||||
x = x.abs().to(orig_dtype).unsqueeze(1) # [B, 1, freq, time]
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
Reference in New Issue
Block a user