fix: add CFG to LoRA mel pre-generation to match inference conditions
Pre-generated mels were using a bare forward pass with no classifier-free guidance (CFG), producing mels that don't match what the vocoder sees at inference (where cfg_strength=4.5 is the default). Pre-generation now uses ode_wrapper with preprocess_conditions/get_empty_conditions, same as the sampler node.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -374,13 +374,14 @@ def _find_audio_for_npz(npz_path: Path):
|
|||||||
|
|
||||||
|
|
||||||
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||||
sample_rate, duration, seed=42, num_steps=25):
|
sample_rate, duration, seed=42, num_steps=25,
|
||||||
|
cfg_strength=4.5):
|
||||||
"""Generate LoRA mels for all clips with matching audio in data_dir.
|
"""Generate LoRA mels for all clips with matching audio in data_dir.
|
||||||
|
|
||||||
Uses the LoRA adapter to run full ODE generation → VAE decode → mel for
|
Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
|
||||||
each clip's conditioning features. Returns (lora_mel, clean_audio) pairs
|
mel for each clip's conditioning features. CFG matches the sampler's
|
||||||
that the vocoder trainer can use: vocoder learns to produce clean audio
|
default (4.5) so the degraded mels the vocoder trains on are representative
|
||||||
from LoRA-distorted mels.
|
of what it will see at inference time.
|
||||||
|
|
||||||
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
|
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
|
||||||
"""
|
"""
|
||||||
@@ -496,13 +497,16 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
|||||||
print(f" [BigVGAN] Failed loading {audio_path.name}: {e}", flush=True)
|
print(f" [BigVGAN] Failed loading {audio_path.name}: {e}", flush=True)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Generate LoRA latent via ODE
|
# Generate LoRA latent via ODE with CFG (matches sampler)
|
||||||
|
conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
||||||
|
empty_conditions = generator.get_empty_conditions(bs=1)
|
||||||
|
|
||||||
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
|
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
|
||||||
device=device, dtype=dtype, generator=rng)
|
device=device, dtype=dtype, generator=rng)
|
||||||
|
|
||||||
def velocity_fn(t, x, _cf=clip_f, _sf=sync_f, _tc=text_clip):
|
def velocity_fn(t, x, _cond=conditions, _empty=empty_conditions,
|
||||||
return generator.forward(x, _cf, _sf, _tc,
|
_cfg=cfg_strength):
|
||||||
t.reshape(1).to(device, dtype))
|
return generator.ode_wrapper(t, x, _cond, _empty, _cfg)
|
||||||
|
|
||||||
x1_pred = fm.to_data(velocity_fn, x0)
|
x1_pred = fm.to_data(velocity_fn, x0)
|
||||||
x1_unnorm = generator.unnormalize(x1_pred.clone())
|
x1_unnorm = generator.unnormalize(x1_pred.clone())
|
||||||
|
|||||||
Reference in New Issue
Block a user