diff --git a/nodes/selva_vae_roundtrip.py b/nodes/selva_vae_roundtrip.py index 0939126..5b1cff4 100644 --- a/nodes/selva_vae_roundtrip.py +++ b/nodes/selva_vae_roundtrip.py @@ -96,8 +96,18 @@ class SelvaVaeRoundtrip: with torch.no_grad(): # Encode - dist = vae.encode_audio(wav_b) - latent = dist.mode().clone() # [1, latent_dim, T] + dist = vae.encode_audio(wav_b) + latent = dist.mode().clone() # [1, latent_dim, T] + + # Trim/pad latent to the exact model sequence length + # (same as _prepare_dataset) so the decoder produces the right duration + tgt = seq_cfg.latent_seq_len + if latent.shape[2] < tgt: + import torch.nn.functional as F + latent = F.pad(latent, (0, tgt - latent.shape[2])) + elif latent.shape[2] > tgt: + latent = latent[:, :, :tgt] + print(f"[VAE Roundtrip] Latent: shape={tuple(latent.shape)} " f"mean={latent.mean():.4f} std={latent.std():.4f}", flush=True)