fix: pad/trim latent to exact latent_seq_len after VAE encoding
STFT hop-size rounding produces ±1 latent frame vs the expected seq length. Clamp to seq_cfg.latent_seq_len after transpose so generator.forward assertion passes. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -340,6 +340,12 @@ class SelvaLoraTrainer:
|
||||
dist = vae_utils.encode_audio(audio_b)
|
||||
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
|
||||
x1 = dist.mode().clone().transpose(1, 2).cpu()
|
||||
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
|
||||
tgt = seq_cfg.latent_seq_len
|
||||
if x1.shape[1] < tgt:
|
||||
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
|
||||
elif x1.shape[1] > tgt:
|
||||
x1 = x1[:, :tgt, :]
|
||||
|
||||
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
||||
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
||||
|
||||
@@ -276,6 +276,12 @@ def main():
|
||||
try:
|
||||
audio = load_audio(audio_path, sample_rate, duration)
|
||||
x1 = extract_audio_latent(audio, feature_utils, device, dtype)
|
||||
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
|
||||
tgt = seq_cfg.latent_seq_len
|
||||
if x1.shape[1] < tgt:
|
||||
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
|
||||
elif x1.shape[1] > tgt:
|
||||
x1 = x1[:, :tgt, :]
|
||||
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
|
||||
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user