fix: cast mel to model dtype before VAE encode in DITTO reference loading

mel_converter outputs float32 (cuFFT requirement), but the VAE encoder weights
are bfloat16. Cast mel to the model dtype before encode to avoid a dtype mismatch.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:18:41 +02:00
parent 056a7b973d
commit 286681edff
+1 -1
@@ -245,7 +245,7 @@ class SelvaDittoOptimizer:
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0).to(device, torch.float32)
-mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T_mel]
+mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel]
# encode → sample → normalize (matches x at ODE endpoint)
z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution
z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]
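
For reviewers who want to reproduce the failure outside the optimizer, the following is a minimal sketch of the mismatch and the cast that resolves it. It is not the project's code: the `encoder` below is a hypothetical single-layer stand-in for the VAE encoder, the random tensor stands in for the `mel_converter` output, and the sizes are arbitrary. It assumes a PyTorch build with bfloat16 convolution support on CPU.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the VAE encoder: one Conv1d with bfloat16 weights.
n_mels, c_lat, t_mel = 8, 4, 16
encoder = nn.Conv1d(n_mels, c_lat, kernel_size=3, padding=1).to(torch.bfloat16)

# Stand-in for the mel_converter output: float32, shape [1, n_mels, T_mel].
mel = torch.randn(1, n_mels, t_mel, dtype=torch.float32)


def raises_runtime_error(x: torch.Tensor) -> bool:
    """Return True if running the encoder on x raises a RuntimeError."""
    try:
        encoder(x)
    except RuntimeError:
        return True
    return False


# float32 input against bfloat16 weights is rejected by the layer.
mismatch_raised = raises_runtime_error(mel)

# The fix: read the dtype from the model and cast the input before encoding.
dtype = next(encoder.parameters()).dtype
z = encoder(mel.to(dtype))  # [1, c_lat, T_mel], in the model dtype
```

Reading the dtype from the module's parameters, as done here, is one way to keep the cast in sync with the weights; the commit itself uses the `dtype` variable already in scope.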