fix: cast mel_converter and wav to float32 before cuFFT in DITTO

cuFFT does not support bfloat16. mel_converter was being moved to the
device without an explicit dtype, inheriting bfloat16 from the model
context, and wav was being cast to `dtype` rather than to float32.
Force float32 for both mel_converter.to() and wav.to() so the STFT
inside the mel converter runs in a supported dtype.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 15:59:55 +02:00
parent 15fc5f0793
commit 82e449681c
+2 -2
View File
@@ -191,7 +191,7 @@ class SelvaDittoOptimizer:
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}") raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True) print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
mel_converter.to(device) mel_converter.to(device, torch.float32) # cuFFT requires float32
ref_mels = [] ref_mels = []
with torch.no_grad(): with torch.no_grad():
@@ -202,7 +202,7 @@ class SelvaDittoOptimizer:
wav = wav.mean(0, keepdim=True) wav = wav.mean(0, keepdim=True)
if sr != sample_rate: if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate) wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0).to(device, dtype) wav = wav.squeeze(0).to(device, torch.float32) # cuFFT requires float32
mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T] mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T]
ref_mels.append(mel) ref_mels.append(mel)
except Exception as e: except Exception as e: