feat: add LoRA mel pre-generation to BigVGAN vocoder trainer
When a lora_adapter path is provided, the trainer pre-generates LoRA-distorted mels for each training clip (full ODE generation + VAE decode) and trains the vocoder to produce clean audio from them. This teaches the vocoder to compensate for LoRA latent distribution shift without requiring perfectly aligned training pairs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+277
-12
@@ -26,10 +26,12 @@ Save format: {'generator': vocoder.state_dict()} — same as the original
|
||||
BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import random
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -354,6 +356,172 @@ def _phase_aware_stft_loss(pred_wav, target_wav, device):
|
||||
return loss / (len(_STFT_RESOLUTIONS) * 3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoRA mel pre-generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_AUDIO_EXTS = (".wav", ".flac", ".mp3", ".ogg", ".aac")
|
||||
|
||||
|
||||
def _find_audio_for_npz(npz_path: Path):
|
||||
"""Find audio file matching an .npz stem (same as LoRA trainer _find_audio)."""
|
||||
for ext in _AUDIO_EXTS:
|
||||
c = npz_path.with_suffix(ext)
|
||||
if c.exists():
|
||||
return c
|
||||
return None
|
||||
|
||||
|
||||
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
sample_rate, duration, seed=42, num_steps=25):
|
||||
"""Generate LoRA mels for all clips with matching audio in data_dir.
|
||||
|
||||
Uses the LoRA adapter to run full ODE generation → VAE decode → mel for
|
||||
each clip's conditioning features. Returns (lora_mel, clean_audio) pairs
|
||||
that the vocoder trainer can use: vocoder learns to produce clean audio
|
||||
from LoRA-distorted mels.
|
||||
|
||||
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
|
||||
"""
|
||||
from selva_core.model.lora import apply_lora, load_lora
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
|
||||
seq_cfg = model["seq_cfg"]
|
||||
feature_utils = model["feature_utils"]
|
||||
|
||||
# Load LoRA checkpoint
|
||||
ckpt = torch.load(str(lora_adapter_path), map_location="cpu", weights_only=False)
|
||||
if isinstance(ckpt, dict) and "state_dict" in ckpt:
|
||||
state_dict = ckpt["state_dict"]
|
||||
meta = ckpt.get("meta", {})
|
||||
else:
|
||||
state_dict = ckpt
|
||||
meta = {}
|
||||
|
||||
rank = int(meta.get("rank", 16))
|
||||
alpha = float(meta.get("alpha", float(rank)))
|
||||
target = list(meta.get("target", ["attn.qkv"]))
|
||||
use_rslora = meta.get("use_rslora", False)
|
||||
|
||||
# Apply LoRA to a temporary generator copy
|
||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||
n = apply_lora(generator, rank=rank, alpha=alpha,
|
||||
target_suffixes=tuple(target),
|
||||
init_mode="standard", use_rslora=use_rslora)
|
||||
load_lora(generator, state_dict)
|
||||
generator.update_seq_lengths(
|
||||
latent_seq_len=seq_cfg.latent_seq_len,
|
||||
clip_seq_len=seq_cfg.clip_seq_len,
|
||||
sync_seq_len=seq_cfg.sync_seq_len,
|
||||
)
|
||||
generator.eval()
|
||||
print(f"[BigVGAN] LoRA loaded: {Path(lora_adapter_path).name} "
|
||||
f"(rank={rank}, {n} layers)", flush=True)
|
||||
|
||||
# Load .npz features + matching audio
|
||||
npz_files = sorted(data_dir.glob("*.npz"))
|
||||
if not npz_files:
|
||||
raise ValueError(f"[BigVGAN] No .npz files in {data_dir} — "
|
||||
"point data_dir to your LoRA training features directory")
|
||||
|
||||
# Load prompt map if available (same logic as LoRA trainer)
|
||||
prompt_map = {}
|
||||
prompts_file = data_dir / "prompts.txt"
|
||||
if prompts_file.exists():
|
||||
for line in prompts_file.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if "|" in line:
|
||||
fname, prompt = line.split("|", 1)
|
||||
prompt_map[fname.strip()] = prompt.strip()
|
||||
default_prompt = data_dir.name
|
||||
|
||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
||||
rng = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
# Move VAE+vocoder to device for decode
|
||||
tod = feature_utils.tod
|
||||
tod_orig_dev = next(tod.parameters()).device
|
||||
tod.to(device)
|
||||
|
||||
pairs = []
|
||||
try:
|
||||
with torch.no_grad():
|
||||
for npz_path in npz_files:
|
||||
audio_path = _find_audio_for_npz(npz_path)
|
||||
if audio_path is None:
|
||||
print(f" [BigVGAN] No audio for {npz_path.name}, skipping", flush=True)
|
||||
continue
|
||||
|
||||
# Load .npz conditioning features
|
||||
data = dict(np.load(str(npz_path), allow_pickle=False))
|
||||
clip_f = torch.from_numpy(data["clip_features"]).to(device, dtype)
|
||||
sync_f = torch.from_numpy(data["sync_features"]).to(device, dtype)
|
||||
|
||||
# Pad/trim to expected sequence lengths
|
||||
c_tgt = seq_cfg.clip_seq_len
|
||||
if clip_f.shape[1] < c_tgt:
|
||||
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
|
||||
elif clip_f.shape[1] > c_tgt:
|
||||
clip_f = clip_f[:, :c_tgt, :]
|
||||
|
||||
s_tgt = seq_cfg.sync_seq_len
|
||||
if sync_f.shape[1] < s_tgt:
|
||||
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
|
||||
elif sync_f.shape[1] > s_tgt:
|
||||
sync_f = sync_f[:, :s_tgt, :]
|
||||
|
||||
# Text CLIP encoding
|
||||
prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt))
|
||||
if isinstance(prompt, np.ndarray):
|
||||
prompt = str(prompt)
|
||||
text_clip = feature_utils.encode_text_clip([prompt]).to(device, dtype)
|
||||
|
||||
# Load clean audio
|
||||
try:
|
||||
wav, sr = _load_wav(audio_path)
|
||||
if wav.shape[0] > 1:
|
||||
wav = wav.mean(0, keepdim=True)
|
||||
if sr != sample_rate:
|
||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
||||
wav = wav.squeeze(0)
|
||||
target_len = int(duration * sample_rate)
|
||||
if wav.shape[0] >= target_len:
|
||||
wav = wav[:target_len]
|
||||
else:
|
||||
wav = F.pad(wav, (0, target_len - wav.shape[0]))
|
||||
except Exception as e:
|
||||
print(f" [BigVGAN] Failed loading {audio_path.name}: {e}", flush=True)
|
||||
continue
|
||||
|
||||
# Generate LoRA latent via ODE
|
||||
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
|
||||
device=device, dtype=dtype, generator=rng)
|
||||
|
||||
def velocity_fn(t, x, _cf=clip_f, _sf=sync_f, _tc=text_clip):
|
||||
return generator.forward(x, _cf, _sf, _tc,
|
||||
t.reshape(1).to(device, dtype))
|
||||
|
||||
x1_pred = fm.to_data(velocity_fn, x0)
|
||||
x1_unnorm = generator.unnormalize(x1_pred.clone())
|
||||
|
||||
# VAE decode → mel
|
||||
mel = feature_utils.decode(x1_unnorm) # [1, n_mels, T_mel]
|
||||
|
||||
pairs.append((mel.squeeze(0).float().cpu(), wav.float().cpu()))
|
||||
del x0, x1_pred, x1_unnorm, mel
|
||||
print(f" [BigVGAN] Generated: {npz_path.stem}", flush=True)
|
||||
|
||||
finally:
|
||||
tod.to(tod_orig_dev)
|
||||
del generator
|
||||
soft_empty_cache()
|
||||
|
||||
print(f"[BigVGAN] Pre-generated {len(pairs)} LoRA mel / clean audio pairs", flush=True)
|
||||
return pairs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -447,12 +615,22 @@ class SelvaBigvganTrainer:
|
||||
"the key fix for harmonic smearing. Leave empty to use mel+STFT losses only."
|
||||
),
|
||||
}),
|
||||
"lora_adapter": ("STRING", {
|
||||
"default": "",
|
||||
"tooltip": (
|
||||
"Optional path to a LoRA adapter .pt file. When provided, the trainer "
|
||||
"pre-generates LoRA-distorted mels for each training clip (using the full "
|
||||
"generation pipeline) and trains the vocoder to produce clean audio from them. "
|
||||
"data_dir must contain .npz feature files alongside audio files "
|
||||
"(same directory used for LoRA training)."
|
||||
),
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
|
||||
segment_seconds, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||
save_every, seed, discriminator_path=""):
|
||||
save_every, seed, discriminator_path="", lora_adapter=""):
|
||||
import traceback
|
||||
|
||||
device = get_device()
|
||||
@@ -491,6 +669,14 @@ class SelvaBigvganTrainer:
|
||||
if not disc_path.exists():
|
||||
raise FileNotFoundError(f"[BigVGAN] Discriminator checkpoint not found: {disc_path}")
|
||||
|
||||
lora_path = None
|
||||
if lora_adapter and lora_adapter.strip():
|
||||
lora_path = Path(lora_adapter.strip())
|
||||
if not lora_path.is_absolute():
|
||||
lora_path = Path(folder_paths.base_path) / lora_path
|
||||
if not lora_path.exists():
|
||||
raise FileNotFoundError(f"[BigVGAN] LoRA adapter not found: {lora_path}")
|
||||
|
||||
# Find and pre-load audio clips
|
||||
segment_samples = int(segment_seconds * sample_rate)
|
||||
audio_files = []
|
||||
@@ -556,6 +742,23 @@ class SelvaBigvganTrainer:
|
||||
|
||||
def _worker():
|
||||
try:
|
||||
# Pre-generate LoRA mels in the worker thread (inference_mode is
|
||||
# thread-local — off here) so deep-copied generator tensors are clean.
|
||||
lora_mel_pairs = None
|
||||
if lora_path is not None:
|
||||
seq_cfg = model["seq_cfg"]
|
||||
lora_mel_pairs = _pregenerate_lora_mels(
|
||||
model, data_dir, str(lora_path),
|
||||
device, dtype, sample_rate,
|
||||
seq_cfg.duration, seed=seed,
|
||||
)
|
||||
if not lora_mel_pairs:
|
||||
raise RuntimeError(
|
||||
"[BigVGAN] LoRA adapter provided but no mel/audio pairs "
|
||||
"could be generated. Check that data_dir contains .npz "
|
||||
"files with matching audio files."
|
||||
)
|
||||
|
||||
_result[0] = _do_train(
|
||||
vocoder, mel_converter, clips,
|
||||
device, dtype, strategy, feature_utils,
|
||||
@@ -563,6 +766,7 @@ class SelvaBigvganTrainer:
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||
use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||
save_every, seed, out_path, disc_path, pbar,
|
||||
lora_mel_pairs,
|
||||
)
|
||||
except Exception as e:
|
||||
_exc[0] = e
|
||||
@@ -586,7 +790,8 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
segment_samples, sample_rate,
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||
use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||
save_every, seed, out_path, disc_path, pbar):
|
||||
save_every, seed, out_path, disc_path, pbar,
|
||||
lora_mel_pairs=None):
|
||||
"""Execute training. Called in a fresh thread — no inference_mode active.
|
||||
|
||||
Even though inference_mode is off here, tensors created in the calling
|
||||
@@ -761,20 +966,62 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
log_file = open(log_path, "w", buffering=1) # line-buffered
|
||||
log_file.write("step,total_loss,fm_loss,mel_loss,stft_loss,phase_loss,l2sp_loss\n")
|
||||
|
||||
# ── Pre-compute mel segment sizes for LoRA mel cropping ───────────────
|
||||
# LoRA mels have shape [n_mels, T_mel_full] for the full clip duration.
|
||||
# We need to crop segment_seconds from both mel and audio at same position.
|
||||
if lora_mel_pairs:
|
||||
_example_mel = lora_mel_pairs[0][0] # [n_mels, T_mel_full]
|
||||
_example_audio = lora_mel_pairs[0][1] # [L]
|
||||
_mel_frames_full = _example_mel.shape[-1]
|
||||
_audio_samples_full = _example_audio.shape[0]
|
||||
# mel frames per audio sample
|
||||
_mel_per_sample = _mel_frames_full / _audio_samples_full
|
||||
_mel_segment = int(segment_samples * _mel_per_sample)
|
||||
print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames "
|
||||
f"per {segment_samples} audio samples", flush=True)
|
||||
|
||||
try:
|
||||
for step in range(steps):
|
||||
# Sample random batch — clips are CPU floats, move to device
|
||||
batch = []
|
||||
for _ in range(batch_size):
|
||||
clip = random.choice(clips)
|
||||
start = random.randint(0, clip.shape[0] - segment_samples)
|
||||
batch.append(clip[start : start + segment_samples])
|
||||
if lora_mel_pairs:
|
||||
# LoRA mode: sample LoRA mel + matching clean audio from same pair.
|
||||
# Crop both from the same time position for alignment.
|
||||
audio_batch = []
|
||||
mel_batch = []
|
||||
for _ in range(batch_size):
|
||||
lora_mel, lora_audio = random.choice(lora_mel_pairs)
|
||||
max_start = lora_audio.shape[0] - segment_samples
|
||||
if max_start > 0:
|
||||
audio_start = random.randint(0, max_start)
|
||||
else:
|
||||
audio_start = 0
|
||||
audio_batch.append(lora_audio[audio_start : audio_start + segment_samples])
|
||||
mel_start = int(audio_start * _mel_per_sample)
|
||||
mel_crop = lora_mel[:, mel_start : mel_start + _mel_segment]
|
||||
# Pad if crop goes past edge
|
||||
if mel_crop.shape[-1] < _mel_segment:
|
||||
mel_crop = F.pad(mel_crop, (0, _mel_segment - mel_crop.shape[-1]))
|
||||
mel_batch.append(mel_crop)
|
||||
|
||||
target_flat = torch.stack(batch).to(device, dtype) # [B, T]
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
target_flat = torch.stack(audio_batch).to(device, dtype) # [B, T]
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
input_mel = torch.stack(mel_batch).to(device, dtype) # [B, n_mels, T_seg]
|
||||
else:
|
||||
# Standard mode: sample random crops from clean audio clips
|
||||
batch = []
|
||||
for _ in range(batch_size):
|
||||
clip = random.choice(clips)
|
||||
start = random.randint(0, clip.shape[0] - segment_samples)
|
||||
batch.append(clip[start : start + segment_samples])
|
||||
|
||||
target_flat = torch.stack(batch).to(device, dtype) # [B, T]
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
|
||||
with torch.no_grad():
|
||||
input_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||
|
||||
# Clean target mel for mel loss (always from clean audio)
|
||||
with torch.no_grad():
|
||||
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||
|
||||
# Gradient checkpointing: recompute BigVGAN activations during
|
||||
# backward instead of storing them. The 512x upsampling stack
|
||||
@@ -782,7 +1029,7 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
# ~2x compute for a large reduction in activation memory, allowing
|
||||
# batch_size > 1 without OOM.
|
||||
pred_wav = torch.utils.checkpoint.checkpoint(
|
||||
vocoder, target_mel, use_reentrant=False
|
||||
vocoder, input_mel, use_reentrant=False
|
||||
) # [B, 1, T_wav]
|
||||
|
||||
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
||||
@@ -882,4 +1129,22 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
torch.save(save_dict, str(out_path))
|
||||
print(f"\n[BigVGAN] Saved: {out_path} gafilter={use_gafilter}", flush=True)
|
||||
_save_sample("final")
|
||||
|
||||
# Generate a LoRA mel → vocoder sample so the user can hear the improvement
|
||||
if lora_mel_pairs:
|
||||
try:
|
||||
lora_mel_full = lora_mel_pairs[0][0] # [n_mels, T_mel]
|
||||
voc_device = next(vocoder.parameters()).device
|
||||
voc_dtype = next(vocoder.parameters()).dtype
|
||||
with torch.no_grad():
|
||||
wav_lora = vocoder(lora_mel_full.unsqueeze(0).to(voc_device, voc_dtype))
|
||||
if wav_lora.dim() == 2:
|
||||
wav_lora = wav_lora.unsqueeze(1)
|
||||
wav_lora = wav_lora.float().cpu().clamp(-1, 1)
|
||||
lora_wav_path = out_path.parent / f"{out_path.stem}_lora_sample.wav"
|
||||
_save_wav(lora_wav_path, wav_lora.squeeze(0), sample_rate)
|
||||
print(f"[BigVGAN] LoRA mel sample: {lora_wav_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[BigVGAN] LoRA sample failed: {e}", flush=True)
|
||||
|
||||
return str(out_path)
|
||||
|
||||
Reference in New Issue
Block a user