fix: trim/pad latent to seq_cfg.latent_seq_len before decoding
Without this the decoder produced 7s instead of 8s due to STFT rounding. Same fix as _prepare_dataset uses for training data. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -96,8 +96,18 @@ class SelvaVaeRoundtrip:
|
||||
|
||||
with torch.no_grad():
|
||||
# Encode
|
||||
dist = vae.encode_audio(wav_b)
|
||||
latent = dist.mode().clone() # [1, latent_dim, T]
|
||||
dist = vae.encode_audio(wav_b)
|
||||
latent = dist.mode().clone() # [1, latent_dim, T]
|
||||
|
||||
# Trim/pad latent to the exact model sequence length
|
||||
# (same as _prepare_dataset) so the decoder produces the right duration
|
||||
tgt = seq_cfg.latent_seq_len
|
||||
if latent.shape[2] < tgt:
|
||||
import torch.nn.functional as F
|
||||
latent = F.pad(latent, (0, tgt - latent.shape[2]))
|
||||
elif latent.shape[2] > tgt:
|
||||
latent = latent[:, :, :tgt]
|
||||
|
||||
print(f"[VAE Roundtrip] Latent: shape={tuple(latent.shape)} "
|
||||
f"mean={latent.mean():.4f} std={latent.std():.4f}", flush=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user