feat: add LoRA mel pre-generation to BigVGAN vocoder trainer
When a lora_adapter path is provided, the trainer pre-generates LoRA-distorted mels for each training clip (full ODE generation + VAE decode) and trains the vocoder to produce clean audio from them. This teaches the vocoder to compensate for LoRA latent distribution shift without requiring perfectly aligned training pairs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+277
-12
@@ -26,10 +26,12 @@ Save format: {'generator': vocoder.state_dict()} — same as the original
|
|||||||
BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
|
BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -354,6 +356,172 @@ def _phase_aware_stft_loss(pred_wav, target_wav, device):
|
|||||||
return loss / (len(_STFT_RESOLUTIONS) * 3)
|
return loss / (len(_STFT_RESOLUTIONS) * 3)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LoRA mel pre-generation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_AUDIO_EXTS = (".wav", ".flac", ".mp3", ".ogg", ".aac")
|
||||||
|
|
||||||
|
|
||||||
|
def _find_audio_for_npz(npz_path: Path):
|
||||||
|
"""Find audio file matching an .npz stem (same as LoRA trainer _find_audio)."""
|
||||||
|
for ext in _AUDIO_EXTS:
|
||||||
|
c = npz_path.with_suffix(ext)
|
||||||
|
if c.exists():
|
||||||
|
return c
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||||
|
sample_rate, duration, seed=42, num_steps=25):
|
||||||
|
"""Generate LoRA mels for all clips with matching audio in data_dir.
|
||||||
|
|
||||||
|
Uses the LoRA adapter to run full ODE generation → VAE decode → mel for
|
||||||
|
each clip's conditioning features. Returns (lora_mel, clean_audio) pairs
|
||||||
|
that the vocoder trainer can use: vocoder learns to produce clean audio
|
||||||
|
from LoRA-distorted mels.
|
||||||
|
|
||||||
|
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
|
||||||
|
"""
|
||||||
|
from selva_core.model.lora import apply_lora, load_lora
|
||||||
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
|
|
||||||
|
seq_cfg = model["seq_cfg"]
|
||||||
|
feature_utils = model["feature_utils"]
|
||||||
|
|
||||||
|
# Load LoRA checkpoint
|
||||||
|
ckpt = torch.load(str(lora_adapter_path), map_location="cpu", weights_only=False)
|
||||||
|
if isinstance(ckpt, dict) and "state_dict" in ckpt:
|
||||||
|
state_dict = ckpt["state_dict"]
|
||||||
|
meta = ckpt.get("meta", {})
|
||||||
|
else:
|
||||||
|
state_dict = ckpt
|
||||||
|
meta = {}
|
||||||
|
|
||||||
|
rank = int(meta.get("rank", 16))
|
||||||
|
alpha = float(meta.get("alpha", float(rank)))
|
||||||
|
target = list(meta.get("target", ["attn.qkv"]))
|
||||||
|
use_rslora = meta.get("use_rslora", False)
|
||||||
|
|
||||||
|
# Apply LoRA to a temporary generator copy
|
||||||
|
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||||
|
n = apply_lora(generator, rank=rank, alpha=alpha,
|
||||||
|
target_suffixes=tuple(target),
|
||||||
|
init_mode="standard", use_rslora=use_rslora)
|
||||||
|
load_lora(generator, state_dict)
|
||||||
|
generator.update_seq_lengths(
|
||||||
|
latent_seq_len=seq_cfg.latent_seq_len,
|
||||||
|
clip_seq_len=seq_cfg.clip_seq_len,
|
||||||
|
sync_seq_len=seq_cfg.sync_seq_len,
|
||||||
|
)
|
||||||
|
generator.eval()
|
||||||
|
print(f"[BigVGAN] LoRA loaded: {Path(lora_adapter_path).name} "
|
||||||
|
f"(rank={rank}, {n} layers)", flush=True)
|
||||||
|
|
||||||
|
# Load .npz features + matching audio
|
||||||
|
npz_files = sorted(data_dir.glob("*.npz"))
|
||||||
|
if not npz_files:
|
||||||
|
raise ValueError(f"[BigVGAN] No .npz files in {data_dir} — "
|
||||||
|
"point data_dir to your LoRA training features directory")
|
||||||
|
|
||||||
|
# Load prompt map if available (same logic as LoRA trainer)
|
||||||
|
prompt_map = {}
|
||||||
|
prompts_file = data_dir / "prompts.txt"
|
||||||
|
if prompts_file.exists():
|
||||||
|
for line in prompts_file.read_text(encoding="utf-8").splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
if "|" in line:
|
||||||
|
fname, prompt = line.split("|", 1)
|
||||||
|
prompt_map[fname.strip()] = prompt.strip()
|
||||||
|
default_prompt = data_dir.name
|
||||||
|
|
||||||
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
||||||
|
rng = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
|
||||||
|
# Move VAE+vocoder to device for decode
|
||||||
|
tod = feature_utils.tod
|
||||||
|
tod_orig_dev = next(tod.parameters()).device
|
||||||
|
tod.to(device)
|
||||||
|
|
||||||
|
pairs = []
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
for npz_path in npz_files:
|
||||||
|
audio_path = _find_audio_for_npz(npz_path)
|
||||||
|
if audio_path is None:
|
||||||
|
print(f" [BigVGAN] No audio for {npz_path.name}, skipping", flush=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Load .npz conditioning features
|
||||||
|
data = dict(np.load(str(npz_path), allow_pickle=False))
|
||||||
|
clip_f = torch.from_numpy(data["clip_features"]).to(device, dtype)
|
||||||
|
sync_f = torch.from_numpy(data["sync_features"]).to(device, dtype)
|
||||||
|
|
||||||
|
# Pad/trim to expected sequence lengths
|
||||||
|
c_tgt = seq_cfg.clip_seq_len
|
||||||
|
if clip_f.shape[1] < c_tgt:
|
||||||
|
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
|
||||||
|
elif clip_f.shape[1] > c_tgt:
|
||||||
|
clip_f = clip_f[:, :c_tgt, :]
|
||||||
|
|
||||||
|
s_tgt = seq_cfg.sync_seq_len
|
||||||
|
if sync_f.shape[1] < s_tgt:
|
||||||
|
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
|
||||||
|
elif sync_f.shape[1] > s_tgt:
|
||||||
|
sync_f = sync_f[:, :s_tgt, :]
|
||||||
|
|
||||||
|
# Text CLIP encoding
|
||||||
|
prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt))
|
||||||
|
if isinstance(prompt, np.ndarray):
|
||||||
|
prompt = str(prompt)
|
||||||
|
text_clip = feature_utils.encode_text_clip([prompt]).to(device, dtype)
|
||||||
|
|
||||||
|
# Load clean audio
|
||||||
|
try:
|
||||||
|
wav, sr = _load_wav(audio_path)
|
||||||
|
if wav.shape[0] > 1:
|
||||||
|
wav = wav.mean(0, keepdim=True)
|
||||||
|
if sr != sample_rate:
|
||||||
|
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
||||||
|
wav = wav.squeeze(0)
|
||||||
|
target_len = int(duration * sample_rate)
|
||||||
|
if wav.shape[0] >= target_len:
|
||||||
|
wav = wav[:target_len]
|
||||||
|
else:
|
||||||
|
wav = F.pad(wav, (0, target_len - wav.shape[0]))
|
||||||
|
except Exception as e:
|
||||||
|
print(f" [BigVGAN] Failed loading {audio_path.name}: {e}", flush=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Generate LoRA latent via ODE
|
||||||
|
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
|
||||||
|
device=device, dtype=dtype, generator=rng)
|
||||||
|
|
||||||
|
def velocity_fn(t, x, _cf=clip_f, _sf=sync_f, _tc=text_clip):
|
||||||
|
return generator.forward(x, _cf, _sf, _tc,
|
||||||
|
t.reshape(1).to(device, dtype))
|
||||||
|
|
||||||
|
x1_pred = fm.to_data(velocity_fn, x0)
|
||||||
|
x1_unnorm = generator.unnormalize(x1_pred.clone())
|
||||||
|
|
||||||
|
# VAE decode → mel
|
||||||
|
mel = feature_utils.decode(x1_unnorm) # [1, n_mels, T_mel]
|
||||||
|
|
||||||
|
pairs.append((mel.squeeze(0).float().cpu(), wav.float().cpu()))
|
||||||
|
del x0, x1_pred, x1_unnorm, mel
|
||||||
|
print(f" [BigVGAN] Generated: {npz_path.stem}", flush=True)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
tod.to(tod_orig_dev)
|
||||||
|
del generator
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
print(f"[BigVGAN] Pre-generated {len(pairs)} LoRA mel / clean audio pairs", flush=True)
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Node
|
# Node
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -447,12 +615,22 @@ class SelvaBigvganTrainer:
|
|||||||
"the key fix for harmonic smearing. Leave empty to use mel+STFT losses only."
|
"the key fix for harmonic smearing. Leave empty to use mel+STFT losses only."
|
||||||
),
|
),
|
||||||
}),
|
}),
|
||||||
|
"lora_adapter": ("STRING", {
|
||||||
|
"default": "",
|
||||||
|
"tooltip": (
|
||||||
|
"Optional path to a LoRA adapter .pt file. When provided, the trainer "
|
||||||
|
"pre-generates LoRA-distorted mels for each training clip (using the full "
|
||||||
|
"generation pipeline) and trains the vocoder to produce clean audio from them. "
|
||||||
|
"data_dir must contain .npz feature files alongside audio files "
|
||||||
|
"(same directory used for LoRA training)."
|
||||||
|
),
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
|
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
|
||||||
segment_seconds, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase,
|
segment_seconds, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||||
save_every, seed, discriminator_path=""):
|
save_every, seed, discriminator_path="", lora_adapter=""):
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
device = get_device()
|
device = get_device()
|
||||||
@@ -491,6 +669,14 @@ class SelvaBigvganTrainer:
|
|||||||
if not disc_path.exists():
|
if not disc_path.exists():
|
||||||
raise FileNotFoundError(f"[BigVGAN] Discriminator checkpoint not found: {disc_path}")
|
raise FileNotFoundError(f"[BigVGAN] Discriminator checkpoint not found: {disc_path}")
|
||||||
|
|
||||||
|
lora_path = None
|
||||||
|
if lora_adapter and lora_adapter.strip():
|
||||||
|
lora_path = Path(lora_adapter.strip())
|
||||||
|
if not lora_path.is_absolute():
|
||||||
|
lora_path = Path(folder_paths.base_path) / lora_path
|
||||||
|
if not lora_path.exists():
|
||||||
|
raise FileNotFoundError(f"[BigVGAN] LoRA adapter not found: {lora_path}")
|
||||||
|
|
||||||
# Find and pre-load audio clips
|
# Find and pre-load audio clips
|
||||||
segment_samples = int(segment_seconds * sample_rate)
|
segment_samples = int(segment_seconds * sample_rate)
|
||||||
audio_files = []
|
audio_files = []
|
||||||
@@ -556,6 +742,23 @@ class SelvaBigvganTrainer:
|
|||||||
|
|
||||||
def _worker():
|
def _worker():
|
||||||
try:
|
try:
|
||||||
|
# Pre-generate LoRA mels in the worker thread (inference_mode is
|
||||||
|
# thread-local — off here) so deep-copied generator tensors are clean.
|
||||||
|
lora_mel_pairs = None
|
||||||
|
if lora_path is not None:
|
||||||
|
seq_cfg = model["seq_cfg"]
|
||||||
|
lora_mel_pairs = _pregenerate_lora_mels(
|
||||||
|
model, data_dir, str(lora_path),
|
||||||
|
device, dtype, sample_rate,
|
||||||
|
seq_cfg.duration, seed=seed,
|
||||||
|
)
|
||||||
|
if not lora_mel_pairs:
|
||||||
|
raise RuntimeError(
|
||||||
|
"[BigVGAN] LoRA adapter provided but no mel/audio pairs "
|
||||||
|
"could be generated. Check that data_dir contains .npz "
|
||||||
|
"files with matching audio files."
|
||||||
|
)
|
||||||
|
|
||||||
_result[0] = _do_train(
|
_result[0] = _do_train(
|
||||||
vocoder, mel_converter, clips,
|
vocoder, mel_converter, clips,
|
||||||
device, dtype, strategy, feature_utils,
|
device, dtype, strategy, feature_utils,
|
||||||
@@ -563,6 +766,7 @@ class SelvaBigvganTrainer:
|
|||||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||||
use_gafilter, gafilter_kernel_size, lambda_phase,
|
use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||||
save_every, seed, out_path, disc_path, pbar,
|
save_every, seed, out_path, disc_path, pbar,
|
||||||
|
lora_mel_pairs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_exc[0] = e
|
_exc[0] = e
|
||||||
@@ -586,7 +790,8 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
segment_samples, sample_rate,
|
segment_samples, sample_rate,
|
||||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||||
use_gafilter, gafilter_kernel_size, lambda_phase,
|
use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||||
save_every, seed, out_path, disc_path, pbar):
|
save_every, seed, out_path, disc_path, pbar,
|
||||||
|
lora_mel_pairs=None):
|
||||||
"""Execute training. Called in a fresh thread — no inference_mode active.
|
"""Execute training. Called in a fresh thread — no inference_mode active.
|
||||||
|
|
||||||
Even though inference_mode is off here, tensors created in the calling
|
Even though inference_mode is off here, tensors created in the calling
|
||||||
@@ -761,20 +966,62 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
log_file = open(log_path, "w", buffering=1) # line-buffered
|
log_file = open(log_path, "w", buffering=1) # line-buffered
|
||||||
log_file.write("step,total_loss,fm_loss,mel_loss,stft_loss,phase_loss,l2sp_loss\n")
|
log_file.write("step,total_loss,fm_loss,mel_loss,stft_loss,phase_loss,l2sp_loss\n")
|
||||||
|
|
||||||
|
# ── Pre-compute mel segment sizes for LoRA mel cropping ───────────────
|
||||||
|
# LoRA mels have shape [n_mels, T_mel_full] for the full clip duration.
|
||||||
|
# We need to crop segment_seconds from both mel and audio at same position.
|
||||||
|
if lora_mel_pairs:
|
||||||
|
_example_mel = lora_mel_pairs[0][0] # [n_mels, T_mel_full]
|
||||||
|
_example_audio = lora_mel_pairs[0][1] # [L]
|
||||||
|
_mel_frames_full = _example_mel.shape[-1]
|
||||||
|
_audio_samples_full = _example_audio.shape[0]
|
||||||
|
# mel frames per audio sample
|
||||||
|
_mel_per_sample = _mel_frames_full / _audio_samples_full
|
||||||
|
_mel_segment = int(segment_samples * _mel_per_sample)
|
||||||
|
print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames "
|
||||||
|
f"per {segment_samples} audio samples", flush=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for step in range(steps):
|
for step in range(steps):
|
||||||
# Sample random batch — clips are CPU floats, move to device
|
if lora_mel_pairs:
|
||||||
batch = []
|
# LoRA mode: sample LoRA mel + matching clean audio from same pair.
|
||||||
for _ in range(batch_size):
|
# Crop both from the same time position for alignment.
|
||||||
clip = random.choice(clips)
|
audio_batch = []
|
||||||
start = random.randint(0, clip.shape[0] - segment_samples)
|
mel_batch = []
|
||||||
batch.append(clip[start : start + segment_samples])
|
for _ in range(batch_size):
|
||||||
|
lora_mel, lora_audio = random.choice(lora_mel_pairs)
|
||||||
|
max_start = lora_audio.shape[0] - segment_samples
|
||||||
|
if max_start > 0:
|
||||||
|
audio_start = random.randint(0, max_start)
|
||||||
|
else:
|
||||||
|
audio_start = 0
|
||||||
|
audio_batch.append(lora_audio[audio_start : audio_start + segment_samples])
|
||||||
|
mel_start = int(audio_start * _mel_per_sample)
|
||||||
|
mel_crop = lora_mel[:, mel_start : mel_start + _mel_segment]
|
||||||
|
# Pad if crop goes past edge
|
||||||
|
if mel_crop.shape[-1] < _mel_segment:
|
||||||
|
mel_crop = F.pad(mel_crop, (0, _mel_segment - mel_crop.shape[-1]))
|
||||||
|
mel_batch.append(mel_crop)
|
||||||
|
|
||||||
target_flat = torch.stack(batch).to(device, dtype) # [B, T]
|
target_flat = torch.stack(audio_batch).to(device, dtype) # [B, T]
|
||||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||||
|
input_mel = torch.stack(mel_batch).to(device, dtype) # [B, n_mels, T_seg]
|
||||||
|
else:
|
||||||
|
# Standard mode: sample random crops from clean audio clips
|
||||||
|
batch = []
|
||||||
|
for _ in range(batch_size):
|
||||||
|
clip = random.choice(clips)
|
||||||
|
start = random.randint(0, clip.shape[0] - segment_samples)
|
||||||
|
batch.append(clip[start : start + segment_samples])
|
||||||
|
|
||||||
|
target_flat = torch.stack(batch).to(device, dtype) # [B, T]
|
||||||
|
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
input_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||||
|
|
||||||
|
# Clean target mel for mel loss (always from clean audio)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||||
|
|
||||||
# Gradient checkpointing: recompute BigVGAN activations during
|
# Gradient checkpointing: recompute BigVGAN activations during
|
||||||
# backward instead of storing them. The 512x upsampling stack
|
# backward instead of storing them. The 512x upsampling stack
|
||||||
@@ -782,7 +1029,7 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
# ~2x compute for a large reduction in activation memory, allowing
|
# ~2x compute for a large reduction in activation memory, allowing
|
||||||
# batch_size > 1 without OOM.
|
# batch_size > 1 without OOM.
|
||||||
pred_wav = torch.utils.checkpoint.checkpoint(
|
pred_wav = torch.utils.checkpoint.checkpoint(
|
||||||
vocoder, target_mel, use_reentrant=False
|
vocoder, input_mel, use_reentrant=False
|
||||||
) # [B, 1, T_wav]
|
) # [B, 1, T_wav]
|
||||||
|
|
||||||
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
||||||
@@ -882,4 +1129,22 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
torch.save(save_dict, str(out_path))
|
torch.save(save_dict, str(out_path))
|
||||||
print(f"\n[BigVGAN] Saved: {out_path} gafilter={use_gafilter}", flush=True)
|
print(f"\n[BigVGAN] Saved: {out_path} gafilter={use_gafilter}", flush=True)
|
||||||
_save_sample("final")
|
_save_sample("final")
|
||||||
|
|
||||||
|
# Generate a LoRA mel → vocoder sample so the user can hear the improvement
|
||||||
|
if lora_mel_pairs:
|
||||||
|
try:
|
||||||
|
lora_mel_full = lora_mel_pairs[0][0] # [n_mels, T_mel]
|
||||||
|
voc_device = next(vocoder.parameters()).device
|
||||||
|
voc_dtype = next(vocoder.parameters()).dtype
|
||||||
|
with torch.no_grad():
|
||||||
|
wav_lora = vocoder(lora_mel_full.unsqueeze(0).to(voc_device, voc_dtype))
|
||||||
|
if wav_lora.dim() == 2:
|
||||||
|
wav_lora = wav_lora.unsqueeze(1)
|
||||||
|
wav_lora = wav_lora.float().cpu().clamp(-1, 1)
|
||||||
|
lora_wav_path = out_path.parent / f"{out_path.stem}_lora_sample.wav"
|
||||||
|
_save_wav(lora_wav_path, wav_lora.squeeze(0), sample_rate)
|
||||||
|
print(f"[BigVGAN] LoRA mel sample: {lora_wav_path}", flush=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[BigVGAN] LoRA sample failed: {e}", flush=True)
|
||||||
|
|
||||||
return str(out_path)
|
return str(out_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user