fix: add CFG to LoRA mel pre-generation to match inference conditions

Pre-generated mels were produced by a bare forward pass with no classifier-free
guidance (CFG), so they did not match what the vocoder sees at inference
(where cfg_strength=4.5 is the default). Pre-generation now runs through
ode_wrapper with preprocess_conditions/get_empty_conditions, same as the sampler node.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 00:17:16 +02:00
parent d06936802b
commit bba5aec7a5
+13 -9
View File
@@ -374,13 +374,14 @@ def _find_audio_for_npz(npz_path: Path):
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
sample_rate, duration, seed=42, num_steps=25):
sample_rate, duration, seed=42, num_steps=25,
cfg_strength=4.5):
"""Generate LoRA mels for all clips with matching audio in data_dir.
Uses the LoRA adapter to run full ODE generation → VAE decode → mel for
each clip's conditioning features. Returns (lora_mel, clean_audio) pairs
that the vocoder trainer can use: vocoder learns to produce clean audio
from LoRA-distorted mels.
Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
mel for each clip's conditioning features. CFG matches the sampler's
default (4.5) so the degraded mels the vocoder trains on are representative
of what it will see at inference time.
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
"""
@@ -496,13 +497,16 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
print(f" [BigVGAN] Failed loading {audio_path.name}: {e}", flush=True)
continue
# Generate LoRA latent via ODE
# Generate LoRA latent via ODE with CFG (matches sampler)
conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = generator.get_empty_conditions(bs=1)
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
device=device, dtype=dtype, generator=rng)
def velocity_fn(t, x, _cf=clip_f, _sf=sync_f, _tc=text_clip):
return generator.forward(x, _cf, _sf, _tc,
t.reshape(1).to(device, dtype))
def velocity_fn(t, x, _cond=conditions, _empty=empty_conditions,
_cfg=cfg_strength):
return generator.ode_wrapper(t, x, _cond, _empty, _cfg)
x1_pred = fm.to_data(velocity_fn, x0)
x1_unnorm = generator.unnormalize(x1_pred.clone())