fix: add CFG to LoRA mel pre-generation to match inference conditions

Pre-generated mels were using a bare forward pass with no classifier-free
guidance, producing mels that don't match what the vocoder sees at inference
(where cfg_strength=4.5 is the default). Now uses ode_wrapper with
preprocess_conditions/get_empty_conditions, same as the sampler node.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 00:17:16 +02:00
parent d06936802b
commit bba5aec7a5
+13 -9
View File
@@ -374,13 +374,14 @@ def _find_audio_for_npz(npz_path: Path):
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
sample_rate, duration, seed=42, num_steps=25): sample_rate, duration, seed=42, num_steps=25,
cfg_strength=4.5):
"""Generate LoRA mels for all clips with matching audio in data_dir. """Generate LoRA mels for all clips with matching audio in data_dir.
Uses the LoRA adapter to run full ODE generation → VAE decode → mel for Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
each clip's conditioning features. Returns (lora_mel, clean_audio) pairs mel for each clip's conditioning features. CFG matches the sampler's
that the vocoder trainer can use: vocoder learns to produce clean audio default (4.5) so the degraded mels the vocoder trains on are representative
from LoRA-distorted mels. of what it will see at inference time.
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors. Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
""" """
@@ -496,13 +497,16 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
print(f" [BigVGAN] Failed loading {audio_path.name}: {e}", flush=True) print(f" [BigVGAN] Failed loading {audio_path.name}: {e}", flush=True)
continue continue
# Generate LoRA latent via ODE # Generate LoRA latent via ODE with CFG (matches sampler)
conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = generator.get_empty_conditions(bs=1)
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim, x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
device=device, dtype=dtype, generator=rng) device=device, dtype=dtype, generator=rng)
def velocity_fn(t, x, _cf=clip_f, _sf=sync_f, _tc=text_clip): def velocity_fn(t, x, _cond=conditions, _empty=empty_conditions,
return generator.forward(x, _cf, _sf, _tc, _cfg=cfg_strength):
t.reshape(1).to(device, dtype)) return generator.ode_wrapper(t, x, _cond, _empty, _cfg)
x1_pred = fm.to_data(velocity_fn, x0) x1_pred = fm.to_data(velocity_fn, x0)
x1_unnorm = generator.unnormalize(x1_pred.clone()) x1_unnorm = generator.unnormalize(x1_pred.clone())