fix: trim/pad latent to seq_cfg.latent_seq_len before decoding
Without this, the decoder produced 7s instead of 8s due to STFT rounding. This is the same fix that _prepare_dataset applies to training data.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -98,6 +98,16 @@ class SelvaVaeRoundtrip:
|
|||||||
# Encode
|
# Encode
|
||||||
dist = vae.encode_audio(wav_b)
|
dist = vae.encode_audio(wav_b)
|
||||||
latent = dist.mode().clone() # [1, latent_dim, T]
|
latent = dist.mode().clone() # [1, latent_dim, T]
|
||||||
|
|
||||||
|
# Trim/pad latent to the exact model sequence length
|
||||||
|
# (same as _prepare_dataset) so the decoder produces the right duration
|
||||||
|
tgt = seq_cfg.latent_seq_len
|
||||||
|
if latent.shape[2] < tgt:
|
||||||
|
import torch.nn.functional as F
|
||||||
|
latent = F.pad(latent, (0, tgt - latent.shape[2]))
|
||||||
|
elif latent.shape[2] > tgt:
|
||||||
|
latent = latent[:, :, :tgt]
|
||||||
|
|
||||||
print(f"[VAE Roundtrip] Latent: shape={tuple(latent.shape)} "
|
print(f"[VAE Roundtrip] Latent: shape={tuple(latent.shape)} "
|
||||||
f"mean={latent.mean():.4f} std={latent.std():.4f}", flush=True)
|
f"mean={latent.mean():.4f} std={latent.std():.4f}", flush=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user