fix: transpose VAE latent from [B,C,T] to [B,T,C] before generator
VAE encoder returns channels-first [B, latent_dim, T]; the generator expects time-first [B, T, latent_dim] (same convention as decode which already does .transpose(1,2)). Fixes normalize() size mismatch. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -338,7 +338,8 @@ class SelvaLoraTrainer:
|
|||||||
# encode_audio is @inference_mode — .clone() exits inference mode
|
# encode_audio is @inference_mode — .clone() exits inference mode
|
||||||
audio_b = audio.unsqueeze(0).to(device)
|
audio_b = audio.unsqueeze(0).to(device)
|
||||||
dist = vae_utils.encode_audio(audio_b)
|
dist = vae_utils.encode_audio(audio_b)
|
||||||
x1 = dist.mode().clone().cpu()
|
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
|
||||||
|
x1 = dist.mode().clone().transpose(1, 2).cpu()
|
||||||
|
|
||||||
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
||||||
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
||||||
|
|||||||
+2
-1
@@ -139,7 +139,8 @@ def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> t
|
|||||||
"""
|
"""
|
||||||
audio_b = audio.unsqueeze(0).to(device, dtype) # [1, L]
|
audio_b = audio.unsqueeze(0).to(device, dtype) # [1, L]
|
||||||
dist = feature_utils.encode_audio(audio_b)
|
dist = feature_utils.encode_audio(audio_b)
|
||||||
return dist.mode().clone().cpu() # [1, seq_len, latent_dim]
|
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
|
||||||
|
return dist.mode().clone().transpose(1, 2).cpu() # [1, seq_len, latent_dim]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user