fix: cast mel_converter and wav to float32 before cuFFT in DITTO
cuFFT does not support bfloat16. mel_converter was being moved to device without an explicit dtype, inheriting bfloat16 from the model context. Force float32 for both mel_converter.to() and wav.to() so the STFT inside the mel converter runs in a supported dtype. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -191,7 +191,7 @@ class SelvaDittoOptimizer:
|
|||||||
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
|
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
|
||||||
|
|
||||||
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
|
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
|
||||||
mel_converter.to(device)
|
mel_converter.to(device, torch.float32) # cuFFT requires float32
|
||||||
|
|
||||||
ref_mels = []
|
ref_mels = []
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -202,7 +202,7 @@ class SelvaDittoOptimizer:
|
|||||||
wav = wav.mean(0, keepdim=True)
|
wav = wav.mean(0, keepdim=True)
|
||||||
if sr != sample_rate:
|
if sr != sample_rate:
|
||||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
||||||
wav = wav.squeeze(0).to(device, dtype)
|
wav = wav.squeeze(0).to(device, torch.float32) # cuFFT requires float32
|
||||||
mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T]
|
mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T]
|
||||||
ref_mels.append(mel)
|
ref_mels.append(mel)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user