fix: cast mel to model dtype before VAE encode in DITTO reference loading
mel_converter outputs float32 (cuFFT requirement), but VAE encoder weights are bfloat16. Cast mel to dtype before encode to avoid type mismatch. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -245,7 +245,7 @@ class SelvaDittoOptimizer:
|
||||
if sr != sample_rate:
|
||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
||||
wav = wav.squeeze(0).to(device, torch.float32)
|
||||
mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||
mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel]
|
||||
# encode → sample → normalize (matches x at ODE endpoint)
|
||||
z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution
|
||||
z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]
|
||||
|
||||
Reference in New Issue
Block a user