fix: correct DITTO reference latent space mismatch

Reference latents were stored in normalized flow-matching space
(net_generator.normalize(z_sample)), but the style loss compares them against
unnormalize(x), which is in VAE latent space. The optimizer was therefore
minimizing an L1 loss between tensors at different scales, pushing the ODE
endpoint out of distribution and producing noise.

Fix: store reference latents in VAE space (z_sample directly) so both
ref_mean/ref_gram and x_un are in the same coordinate system.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 18:57:08 +02:00
parent 14fabf01f9
commit 8fa2699551
+2 -3
View File
@@ -253,11 +253,10 @@ class SelvaDittoOptimizer:
wav = torchaudio.functional.resample(wav, sr, sample_rate) wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0).to(device, torch.float32) wav = wav.squeeze(0).to(device, torch.float32)
mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel] mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel]
# encode → sample → normalize (matches x at ODE endpoint) # encode → sample → VAE latent space (matches unnormalize(x) in loss)
z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution
z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat] z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]
z_norm = net_generator.normalize(z_sample.to(dtype)) ref_latents.append(z_sample.to(dtype).squeeze(0).clone()) # [T_lat, C_lat]
ref_latents.append(z_norm.squeeze(0).clone()) # [T_lat, C_lat]
except Exception as e: except Exception as e:
print(f" [DITTO] Skip {rf.name}: {e}", flush=True) print(f" [DITTO] Skip {rf.name}: {e}", flush=True)