diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index f2cd95b..e18835b 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -338,7 +338,8 @@ class SelvaLoraTrainer: # encode_audio is @inference_mode — .clone() exits inference mode audio_b = audio.unsqueeze(0).to(device) dist = vae_utils.encode_audio(audio_b) - x1 = dist.mode().clone().cpu() + # VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim] + x1 = dist.mode().clone().transpose(1, 2).cpu() # Text → CLIP features (reuse already-loaded CLIP from inference model) text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu() diff --git a/train_lora.py b/train_lora.py index 8942725..8ce1856 100644 --- a/train_lora.py +++ b/train_lora.py @@ -139,7 +139,8 @@ def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> t """ audio_b = audio.unsqueeze(0).to(device, dtype) # [1, L] dist = feature_utils.encode_audio(audio_b) - return dist.mode().clone().cpu() # [1, seq_len, latent_dim] + # VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim] + return dist.mode().clone().transpose(1, 2).cpu() # [1, seq_len, latent_dim] # ---------------------------------------------------------------------------