fix: pad/trim latent to exact latent_seq_len after VAE encoding
STFT hop-size rounding produces ±1 latent frame vs the expected seq length. Clamp to seq_cfg.latent_seq_len after transpose so generator.forward assertion passes. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -339,7 +339,13 @@ class SelvaLoraTrainer:
|
|||||||
audio_b = audio.unsqueeze(0).to(device)
|
audio_b = audio.unsqueeze(0).to(device)
|
||||||
dist = vae_utils.encode_audio(audio_b)
|
dist = vae_utils.encode_audio(audio_b)
|
||||||
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
|
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
|
||||||
x1 = dist.mode().clone().transpose(1, 2).cpu()
|
x1 = dist.mode().clone().transpose(1, 2).cpu()
|
||||||
|
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
|
||||||
|
tgt = seq_cfg.latent_seq_len
|
||||||
|
if x1.shape[1] < tgt:
|
||||||
|
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
|
||||||
|
elif x1.shape[1] > tgt:
|
||||||
|
x1 = x1[:, :tgt, :]
|
||||||
|
|
||||||
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
||||||
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
||||||
|
|||||||
@@ -276,6 +276,12 @@ def main():
|
|||||||
try:
|
try:
|
||||||
audio = load_audio(audio_path, sample_rate, duration)
|
audio = load_audio(audio_path, sample_rate, duration)
|
||||||
x1 = extract_audio_latent(audio, feature_utils, device, dtype)
|
x1 = extract_audio_latent(audio, feature_utils, device, dtype)
|
||||||
|
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
|
||||||
|
tgt = seq_cfg.latent_seq_len
|
||||||
|
if x1.shape[1] < tgt:
|
||||||
|
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
|
||||||
|
elif x1.shape[1] > tgt:
|
||||||
|
x1 = x1[:, :tgt, :]
|
||||||
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
|
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
|
||||||
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
|
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user