From ad5743280348d7f3fe61c7f712f477408699b34f Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 5 Apr 2026 22:12:20 +0200 Subject: [PATCH] fix: pad/trim latent to exact latent_seq_len after VAE encoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit STFT hop-size rounding produces ±1 latent frame vs the expected seq length. Clamp to seq_cfg.latent_seq_len after transpose so generator.forward assertion passes. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_lora_trainer.py | 8 +++++++- train_lora.py | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index e18835b..30e1fe1 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -339,7 +339,13 @@ class SelvaLoraTrainer: audio_b = audio.unsqueeze(0).to(device) dist = vae_utils.encode_audio(audio_b) # VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim] - x1 = dist.mode().clone().transpose(1, 2).cpu() + x1 = dist.mode().clone().transpose(1, 2).cpu() + # STFT rounding can produce ±1 frame — pad or trim to exact seq length + tgt = seq_cfg.latent_seq_len + if x1.shape[1] < tgt: + x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1])) + elif x1.shape[1] > tgt: + x1 = x1[:, :tgt, :] # Text → CLIP features (reuse already-loaded CLIP from inference model) text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu() diff --git a/train_lora.py b/train_lora.py index 8ce1856..b15a0fb 100644 --- a/train_lora.py +++ b/train_lora.py @@ -276,6 +276,12 @@ def main(): try: audio = load_audio(audio_path, sample_rate, duration) x1 = extract_audio_latent(audio, feature_utils, device, dtype) + # STFT rounding can produce ±1 frame — pad or trim to exact seq length + tgt = seq_cfg.latent_seq_len + if x1.shape[1] < tgt: + x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1])) + elif x1.shape[1] > tgt: + x1 = x1[:, :tgt, :] text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu() dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip)) except Exception as e: