diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index f00f5b4..8017ec2 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -245,7 +245,7 @@ class SelvaDittoOptimizer: if sr != sample_rate: wav = torchaudio.functional.resample(wav, sr, sample_rate) wav = wav.squeeze(0).to(device, torch.float32) - mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T_mel] + mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel] # encode → sample → normalize (matches x at ODE endpoint) z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]