fix: trim/pad latent to seq_cfg.latent_seq_len before decoding

Without this the decoder produced 7s instead of 8s due to STFT rounding.
Same fix as _prepare_dataset uses for training data.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 19:22:09 +02:00
parent 8195c3114a
commit 528d33be39
+12 -2
View File
@@ -96,8 +96,18 @@ class SelvaVaeRoundtrip:
with torch.no_grad():
# Encode
dist = vae.encode_audio(wav_b)
latent = dist.mode().clone() # [1, latent_dim, T]
dist = vae.encode_audio(wav_b)
latent = dist.mode().clone() # [1, latent_dim, T]
# Trim/pad latent to the exact model sequence length
# (same as _prepare_dataset) so the decoder produces the right duration
tgt = seq_cfg.latent_seq_len
if latent.shape[2] < tgt:
import torch.nn.functional as F
latent = F.pad(latent, (0, tgt - latent.shape[2]))
elif latent.shape[2] > tgt:
latent = latent[:, :, :tgt]
print(f"[VAE Roundtrip] Latent: shape={tuple(latent.shape)} "
f"mean={latent.mean():.4f} std={latent.std():.4f}", flush=True)