feat: add LoRA mel pre-generation to BigVGAN vocoder trainer

When a lora_adapter path is provided, the trainer pre-generates
LoRA-distorted mels for each training clip (full ODE generation +
VAE decode) and trains the vocoder to produce clean audio from them.
This teaches the vocoder to compensate for LoRA latent distribution
shift without requiring perfectly aligned training pairs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 23:26:36 +02:00
parent e16480b4c9
commit 48b72c0be0
+277 -12
View File
@@ -26,10 +26,12 @@ Save format: {'generator': vocoder.state_dict()} — same as the original
BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
"""
import copy
import random
import threading
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -354,6 +356,172 @@ def _phase_aware_stft_loss(pred_wav, target_wav, device):
return loss / (len(_STFT_RESOLUTIONS) * 3)
# ---------------------------------------------------------------------------
# LoRA mel pre-generation
# ---------------------------------------------------------------------------
_AUDIO_EXTS = (".wav", ".flac", ".mp3", ".ogg", ".aac")
def _find_audio_for_npz(npz_path: Path):
"""Find audio file matching an .npz stem (same as LoRA trainer _find_audio)."""
for ext in _AUDIO_EXTS:
c = npz_path.with_suffix(ext)
if c.exists():
return c
return None
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
sample_rate, duration, seed=42, num_steps=25):
"""Generate LoRA mels for all clips with matching audio in data_dir.
Uses the LoRA adapter to run full ODE generation → VAE decode → mel for
each clip's conditioning features. Returns (lora_mel, clean_audio) pairs
that the vocoder trainer can use: vocoder learns to produce clean audio
from LoRA-distorted mels.
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
"""
from selva_core.model.lora import apply_lora, load_lora
from selva_core.model.flow_matching import FlowMatching
seq_cfg = model["seq_cfg"]
feature_utils = model["feature_utils"]
# Load LoRA checkpoint
ckpt = torch.load(str(lora_adapter_path), map_location="cpu", weights_only=False)
if isinstance(ckpt, dict) and "state_dict" in ckpt:
state_dict = ckpt["state_dict"]
meta = ckpt.get("meta", {})
else:
state_dict = ckpt
meta = {}
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
use_rslora = meta.get("use_rslora", False)
# Apply LoRA to a temporary generator copy
generator = copy.deepcopy(model["generator"]).to(device, dtype)
n = apply_lora(generator, rank=rank, alpha=alpha,
target_suffixes=tuple(target),
init_mode="standard", use_rslora=use_rslora)
load_lora(generator, state_dict)
generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
generator.eval()
print(f"[BigVGAN] LoRA loaded: {Path(lora_adapter_path).name} "
f"(rank={rank}, {n} layers)", flush=True)
# Load .npz features + matching audio
npz_files = sorted(data_dir.glob("*.npz"))
if not npz_files:
raise ValueError(f"[BigVGAN] No .npz files in {data_dir}"
"point data_dir to your LoRA training features directory")
# Load prompt map if available (same logic as LoRA trainer)
prompt_map = {}
prompts_file = data_dir / "prompts.txt"
if prompts_file.exists():
for line in prompts_file.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "|" in line:
fname, prompt = line.split("|", 1)
prompt_map[fname.strip()] = prompt.strip()
default_prompt = data_dir.name
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
rng = torch.Generator(device=device).manual_seed(seed)
# Move VAE+vocoder to device for decode
tod = feature_utils.tod
tod_orig_dev = next(tod.parameters()).device
tod.to(device)
pairs = []
try:
with torch.no_grad():
for npz_path in npz_files:
audio_path = _find_audio_for_npz(npz_path)
if audio_path is None:
print(f" [BigVGAN] No audio for {npz_path.name}, skipping", flush=True)
continue
# Load .npz conditioning features
data = dict(np.load(str(npz_path), allow_pickle=False))
clip_f = torch.from_numpy(data["clip_features"]).to(device, dtype)
sync_f = torch.from_numpy(data["sync_features"]).to(device, dtype)
# Pad/trim to expected sequence lengths
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
# Text CLIP encoding
prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt))
if isinstance(prompt, np.ndarray):
prompt = str(prompt)
text_clip = feature_utils.encode_text_clip([prompt]).to(device, dtype)
# Load clean audio
try:
wav, sr = _load_wav(audio_path)
if wav.shape[0] > 1:
wav = wav.mean(0, keepdim=True)
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0)
target_len = int(duration * sample_rate)
if wav.shape[0] >= target_len:
wav = wav[:target_len]
else:
wav = F.pad(wav, (0, target_len - wav.shape[0]))
except Exception as e:
print(f" [BigVGAN] Failed loading {audio_path.name}: {e}", flush=True)
continue
# Generate LoRA latent via ODE
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
device=device, dtype=dtype, generator=rng)
def velocity_fn(t, x, _cf=clip_f, _sf=sync_f, _tc=text_clip):
return generator.forward(x, _cf, _sf, _tc,
t.reshape(1).to(device, dtype))
x1_pred = fm.to_data(velocity_fn, x0)
x1_unnorm = generator.unnormalize(x1_pred.clone())
# VAE decode → mel
mel = feature_utils.decode(x1_unnorm) # [1, n_mels, T_mel]
pairs.append((mel.squeeze(0).float().cpu(), wav.float().cpu()))
del x0, x1_pred, x1_unnorm, mel
print(f" [BigVGAN] Generated: {npz_path.stem}", flush=True)
finally:
tod.to(tod_orig_dev)
del generator
soft_empty_cache()
print(f"[BigVGAN] Pre-generated {len(pairs)} LoRA mel / clean audio pairs", flush=True)
return pairs
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
@@ -447,12 +615,22 @@ class SelvaBigvganTrainer:
"the key fix for harmonic smearing. Leave empty to use mel+STFT losses only."
),
}),
"lora_adapter": ("STRING", {
"default": "",
"tooltip": (
"Optional path to a LoRA adapter .pt file. When provided, the trainer "
"pre-generates LoRA-distorted mels for each training clip (using the full "
"generation pipeline) and trains the vocoder to produce clean audio from them. "
"data_dir must contain .npz feature files alongside audio files "
"(same directory used for LoRA training)."
),
}),
},
}
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
segment_seconds, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase,
save_every, seed, discriminator_path=""):
save_every, seed, discriminator_path="", lora_adapter=""):
import traceback
device = get_device()
@@ -491,6 +669,14 @@ class SelvaBigvganTrainer:
if not disc_path.exists():
raise FileNotFoundError(f"[BigVGAN] Discriminator checkpoint not found: {disc_path}")
lora_path = None
if lora_adapter and lora_adapter.strip():
lora_path = Path(lora_adapter.strip())
if not lora_path.is_absolute():
lora_path = Path(folder_paths.base_path) / lora_path
if not lora_path.exists():
raise FileNotFoundError(f"[BigVGAN] LoRA adapter not found: {lora_path}")
# Find and pre-load audio clips
segment_samples = int(segment_seconds * sample_rate)
audio_files = []
@@ -556,6 +742,23 @@ class SelvaBigvganTrainer:
def _worker():
try:
# Pre-generate LoRA mels in the worker thread (inference_mode is
# thread-local — off here) so deep-copied generator tensors are clean.
lora_mel_pairs = None
if lora_path is not None:
seq_cfg = model["seq_cfg"]
lora_mel_pairs = _pregenerate_lora_mels(
model, data_dir, str(lora_path),
device, dtype, sample_rate,
seq_cfg.duration, seed=seed,
)
if not lora_mel_pairs:
raise RuntimeError(
"[BigVGAN] LoRA adapter provided but no mel/audio pairs "
"could be generated. Check that data_dir contains .npz "
"files with matching audio files."
)
_result[0] = _do_train(
vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils,
@@ -563,6 +766,7 @@ class SelvaBigvganTrainer:
train_mode, steps, lr, batch_size, lambda_l2sp,
use_gafilter, gafilter_kernel_size, lambda_phase,
save_every, seed, out_path, disc_path, pbar,
lora_mel_pairs,
)
except Exception as e:
_exc[0] = e
@@ -586,7 +790,8 @@ def _do_train(vocoder, mel_converter, clips,
segment_samples, sample_rate,
train_mode, steps, lr, batch_size, lambda_l2sp,
use_gafilter, gafilter_kernel_size, lambda_phase,
save_every, seed, out_path, disc_path, pbar):
save_every, seed, out_path, disc_path, pbar,
lora_mel_pairs=None):
"""Execute training. Called in a fresh thread — no inference_mode active.
Even though inference_mode is off here, tensors created in the calling
@@ -761,20 +966,62 @@ def _do_train(vocoder, mel_converter, clips,
log_file = open(log_path, "w", buffering=1) # line-buffered
log_file.write("step,total_loss,fm_loss,mel_loss,stft_loss,phase_loss,l2sp_loss\n")
# ── Pre-compute mel segment sizes for LoRA mel cropping ───────────────
# LoRA mels have shape [n_mels, T_mel_full] for the full clip duration.
# We need to crop segment_seconds from both mel and audio at same position.
if lora_mel_pairs:
_example_mel = lora_mel_pairs[0][0] # [n_mels, T_mel_full]
_example_audio = lora_mel_pairs[0][1] # [L]
_mel_frames_full = _example_mel.shape[-1]
_audio_samples_full = _example_audio.shape[0]
# mel frames per audio sample
_mel_per_sample = _mel_frames_full / _audio_samples_full
_mel_segment = int(segment_samples * _mel_per_sample)
print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames "
f"per {segment_samples} audio samples", flush=True)
try:
for step in range(steps):
# Sample random batch — clips are CPU floats, move to device
batch = []
for _ in range(batch_size):
clip = random.choice(clips)
start = random.randint(0, clip.shape[0] - segment_samples)
batch.append(clip[start : start + segment_samples])
if lora_mel_pairs:
# LoRA mode: sample LoRA mel + matching clean audio from same pair.
# Crop both from the same time position for alignment.
audio_batch = []
mel_batch = []
for _ in range(batch_size):
lora_mel, lora_audio = random.choice(lora_mel_pairs)
max_start = lora_audio.shape[0] - segment_samples
if max_start > 0:
audio_start = random.randint(0, max_start)
else:
audio_start = 0
audio_batch.append(lora_audio[audio_start : audio_start + segment_samples])
mel_start = int(audio_start * _mel_per_sample)
mel_crop = lora_mel[:, mel_start : mel_start + _mel_segment]
# Pad if crop goes past edge
if mel_crop.shape[-1] < _mel_segment:
mel_crop = F.pad(mel_crop, (0, _mel_segment - mel_crop.shape[-1]))
mel_batch.append(mel_crop)
target_flat = torch.stack(batch).to(device, dtype) # [B, T]
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
target_flat = torch.stack(audio_batch).to(device, dtype) # [B, T]
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
input_mel = torch.stack(mel_batch).to(device, dtype) # [B, n_mels, T_seg]
else:
# Standard mode: sample random crops from clean audio clips
batch = []
for _ in range(batch_size):
clip = random.choice(clips)
start = random.randint(0, clip.shape[0] - segment_samples)
batch.append(clip[start : start + segment_samples])
target_flat = torch.stack(batch).to(device, dtype) # [B, T]
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
with torch.no_grad():
input_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
# Clean target mel for mel loss (always from clean audio)
with torch.no_grad():
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
# Gradient checkpointing: recompute BigVGAN activations during
# backward instead of storing them. The 512x upsampling stack
@@ -782,7 +1029,7 @@ def _do_train(vocoder, mel_converter, clips,
# ~2x compute for a large reduction in activation memory, allowing
# batch_size > 1 without OOM.
pred_wav = torch.utils.checkpoint.checkpoint(
vocoder, target_mel, use_reentrant=False
vocoder, input_mel, use_reentrant=False
) # [B, 1, T_wav]
T = min(pred_wav.shape[-1], target_wav.shape[-1])
@@ -882,4 +1129,22 @@ def _do_train(vocoder, mel_converter, clips,
torch.save(save_dict, str(out_path))
print(f"\n[BigVGAN] Saved: {out_path} gafilter={use_gafilter}", flush=True)
_save_sample("final")
# Generate a LoRA mel → vocoder sample so the user can hear the improvement
if lora_mel_pairs:
try:
lora_mel_full = lora_mel_pairs[0][0] # [n_mels, T_mel]
voc_device = next(vocoder.parameters()).device
voc_dtype = next(vocoder.parameters()).dtype
with torch.no_grad():
wav_lora = vocoder(lora_mel_full.unsqueeze(0).to(voc_device, voc_dtype))
if wav_lora.dim() == 2:
wav_lora = wav_lora.unsqueeze(1)
wav_lora = wav_lora.float().cpu().clamp(-1, 1)
lora_wav_path = out_path.parent / f"{out_path.stem}_lora_sample.wav"
_save_wav(lora_wav_path, wav_lora.squeeze(0), sample_rate)
print(f"[BigVGAN] LoRA mel sample: {lora_wav_path}", flush=True)
except Exception as e:
print(f"[BigVGAN] LoRA sample failed: {e}", flush=True)
return str(out_path)