From 528d33be390e0c48381f703fed8f6a8c67983497 Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Wed, 8 Apr 2026 19:22:09 +0200
Subject: [PATCH] fix: trim/pad latent to seq_cfg.latent_seq_len before decoding

Without this the decoder produced 7s instead of 8s due to STFT rounding.
Same fix as _prepare_dataset uses for training data.

Co-Authored-By: Claude Sonnet 4.6
---
 nodes/selva_vae_roundtrip.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/nodes/selva_vae_roundtrip.py b/nodes/selva_vae_roundtrip.py
index 0939126..5b1cff4 100644
--- a/nodes/selva_vae_roundtrip.py
+++ b/nodes/selva_vae_roundtrip.py
@@ -96,8 +96,18 @@ class SelvaVaeRoundtrip:

         with torch.no_grad():
             # Encode
-            dist = vae.encode_audio(wav_b)
-            latent = dist.mode().clone() # [1, latent_dim, T]
+            dist = vae.encode_audio(wav_b)
+            latent = dist.mode().clone() # [1, latent_dim, T]
+
+            # Trim/pad latent to the exact model sequence length
+            # (same as _prepare_dataset) so the decoder produces the right duration
+            tgt = seq_cfg.latent_seq_len
+            if latent.shape[2] < tgt:
+                import torch.nn.functional as F
+                latent = F.pad(latent, (0, tgt - latent.shape[2]))
+            elif latent.shape[2] > tgt:
+                latent = latent[:, :, :tgt]
+
             print(f"[VAE Roundtrip] Latent: shape={tuple(latent.shape)} "
                   f"mean={latent.mean():.4f} std={latent.std():.4f}", flush=True)
