From bba5aec7a5025e5610ae83b0b8e8c5c5345485d6 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 10 Apr 2026 00:17:16 +0200 Subject: [PATCH] fix: add CFG to LoRA mel pre-generation to match inference conditions Pre-generated mels were using a bare forward pass with no classifier-free guidance, producing mels that don't match what the vocoder sees at inference (where cfg_strength=4.5 is the default). Now uses ode_wrapper with preprocess_conditions/get_empty_conditions, same as the sampler node. Co-Authored-By: Claude Opus 4.6 --- nodes/selva_bigvgan_trainer.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 24d6425..2de8eff 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -374,13 +374,14 @@ def _find_audio_for_npz(npz_path: Path): def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, - sample_rate, duration, seed=42, num_steps=25): + sample_rate, duration, seed=42, num_steps=25, + cfg_strength=4.5): """Generate LoRA mels for all clips with matching audio in data_dir. - Uses the LoRA adapter to run full ODE generation → VAE decode → mel for - each clip's conditioning features. Returns (lora_mel, clean_audio) pairs - that the vocoder trainer can use: vocoder learns to produce clean audio - from LoRA-distorted mels. + Uses the LoRA adapter to run full ODE generation with CFG → VAE decode → + mel for each clip's conditioning features. CFG matches the sampler's + default (4.5) so the degraded mels the vocoder trains on are representative + of what it will see at inference time. Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors. """ @@ -496,13 +497,16 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, print(f" [BigVGAN] Failed loading {audio_path.name}: {e}", flush=True) continue - # Generate LoRA latent via ODE + # Generate LoRA latent via ODE with CFG (matches sampler) + conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip) + empty_conditions = generator.get_empty_conditions(bs=1) + x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim, device=device, dtype=dtype, generator=rng) - def velocity_fn(t, x, _cf=clip_f, _sf=sync_f, _tc=text_clip): - return generator.forward(x, _cf, _sf, _tc, - t.reshape(1).to(device, dtype)) + def velocity_fn(t, x, _cond=conditions, _empty=empty_conditions, + _cfg=cfg_strength): + return generator.ode_wrapper(t, x, _cond, _empty, _cfg) x1_pred = fm.to_data(velocity_fn, x0) x1_unnorm = generator.unnormalize(x1_pred.clone())