fix: pad/trim latent to exact latent_seq_len after VAE encoding

STFT hop-size rounding produces ±1 latent frame vs the expected seq length.
Clamp to seq_cfg.latent_seq_len after transpose so generator.forward assertion passes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 22:12:20 +02:00
parent 43f732f904
commit ad57432803
2 changed files with 13 additions and 1 deletions
+7 -1
View File
@@ -339,7 +339,13 @@ class SelvaLoraTrainer:
audio_b = audio.unsqueeze(0).to(device)
dist = vae_utils.encode_audio(audio_b)
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
x1 = dist.mode().clone().transpose(1, 2).cpu()
x1 = dist.mode().clone().transpose(1, 2).cpu()
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
tgt = seq_cfg.latent_seq_len
if x1.shape[1] < tgt:
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
elif x1.shape[1] > tgt:
x1 = x1[:, :tgt, :]
# Text → CLIP features (reuse already-loaded CLIP from inference model)
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()