# ---------------------------------------------------------------------------
# LoRA mel pre-generation
# ---------------------------------------------------------------------------

_AUDIO_EXTS = (".wav", ".flac", ".mp3", ".ogg", ".aac")


def _find_audio_for_npz(npz_path: Path):
    """Return the audio file that shares ``npz_path``'s stem, or ``None``.

    Extensions are tried in ``_AUDIO_EXTS`` order. Each one is tried
    lower-case first and then upper-case, so ``clip.WAV`` is also found on
    case-sensitive filesystems.
    """
    for ext in _AUDIO_EXTS:
        for suffix in (ext, ext.upper()):
            candidate = npz_path.with_suffix(suffix)
            if candidate.exists():
                return candidate
    return None


def _load_prompt_map(data_dir: Path) -> dict:
    """Parse ``data_dir/prompts.txt`` (``name|prompt`` per line) into a dict.

    Blank lines, lines starting with ``#`` and lines without a ``|`` are
    ignored. Returns an empty dict when the file does not exist.
    """
    prompts_file = data_dir / "prompts.txt"
    prompt_map = {}
    if not prompts_file.exists():
        return prompt_map
    for line in prompts_file.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "|" not in line:
            continue
        fname, prompt = line.split("|", 1)
        prompt_map[fname.strip()] = prompt.strip()
    return prompt_map


def _fit_seq_len(features, target_len):
    """Zero-pad or trim a 3-D ``[:, seq, :]`` tensor to ``target_len`` on dim 1."""
    current = features.shape[1]
    if current < target_len:
        # F.pad pairs run from the last dim backwards: (0, 0) leaves the
        # feature dim alone, (0, n) appends n zero rows to the sequence dim.
        return F.pad(features, (0, 0, 0, target_len - current))
    if current > target_len:
        return features[:, :target_len, :]
    return features


def _as_prompt_text(value):
    """Turn a prompt read from an .npz (possibly a numpy array) into text."""
    if isinstance(value, np.ndarray):
        # A single stored string comes back as a 0-d or 1-element array;
        # .item() yields the bare string instead of "['...']".
        return str(value.item()) if value.size == 1 else str(value)
    return value


def _load_clean_audio(audio_path, sample_rate, duration):
    """Load ``audio_path`` as a mono 1-D waveform of ``duration`` seconds.

    The waveform is down-mixed, resampled to ``sample_rate`` when needed and
    then trimmed or zero-padded to ``int(duration * sample_rate)`` samples.
    """
    # assumes _load_wav returns (tensor [channels, samples], sample_rate)
    # — TODO confirm against its definition elsewhere in this file
    wav, sr = _load_wav(audio_path)
    if wav.shape[0] > 1:
        wav = wav.mean(0, keepdim=True)
    if sr != sample_rate:
        wav = torchaudio.functional.resample(wav, sr, sample_rate)
    wav = wav.squeeze(0)
    target_len = int(duration * sample_rate)
    if wav.shape[0] >= target_len:
        return wav[:target_len]
    return F.pad(wav, (0, target_len - wav.shape[0]))


def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
                           sample_rate, duration, seed=42, num_steps=25):
    """Generate LoRA mels for all clips with matching audio in data_dir.

    Uses the LoRA adapter to run full ODE generation -> VAE decode -> mel for
    each clip's conditioning features. Returns (lora_mel, clean_audio) pairs
    that the vocoder trainer can use: the vocoder learns to produce clean
    audio from LoRA-distorted mels.

    Args:
        model: dict providing ``"seq_cfg"``, ``"feature_utils"`` and
            ``"generator"``.
        data_dir: directory (``Path`` or ``str``) holding ``*.npz`` feature
            files, their same-stem audio files and an optional
            ``prompts.txt``.
        lora_adapter_path: path of the LoRA checkpoint; either a bare state
            dict or ``{"state_dict": ..., "meta": {...}}``.
        device, dtype: where and in which precision generation runs.
        sample_rate: sample rate the clean audio is resampled to.
        duration: clip length in seconds; audio is trimmed/padded to it.
        seed: seed of the noise generator used for the initial latents.
        num_steps: number of Euler steps of the flow-matching solver.

    Returns:
        list of ``(mel [n_mels, T_mel], audio [L])`` CPU float tensors. Clips
        without audio, or whose audio fails to load, are skipped.

    Raises:
        ValueError: if ``data_dir`` contains no ``.npz`` files.
    """
    # Validate the dataset before any expensive work (checkpoint load, deep
    # copy of the generator onto the device).
    data_dir = Path(data_dir)
    npz_files = sorted(data_dir.glob("*.npz"))
    if not npz_files:
        raise ValueError(f"[BigVGAN] No .npz files in {data_dir} — "
                         "point data_dir to your LoRA training features directory")

    from selva_core.model.lora import apply_lora, load_lora
    from selva_core.model.flow_matching import FlowMatching

    seq_cfg = model["seq_cfg"]
    feature_utils = model["feature_utils"]

    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load adapter files from a trusted source.
    ckpt = torch.load(str(lora_adapter_path), map_location="cpu", weights_only=False)
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
        meta = ckpt.get("meta") or {}  # tolerate an explicit "meta": None
    else:
        state_dict, meta = ckpt, {}

    rank = int(meta.get("rank", 16))
    alpha = float(meta.get("alpha", float(rank)))
    target = list(meta.get("target", ["attn.qkv"]))
    use_rslora = meta.get("use_rslora", False)

    # Per-file prompts override the prompt stored in the .npz, which in turn
    # overrides the directory name.
    prompt_map = _load_prompt_map(data_dir)
    default_prompt = data_dir.name

    fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
    rng = torch.Generator(device=device).manual_seed(seed)

    tod = feature_utils.tod
    tod_orig_dev = next(tod.parameters()).device

    generator = None
    pairs = []
    try:
        # Apply LoRA to a temporary generator copy. Done inside the try so a
        # failure while patching still releases the device copy below.
        generator = copy.deepcopy(model["generator"]).to(device, dtype)
        n = apply_lora(generator, rank=rank, alpha=alpha,
                       target_suffixes=tuple(target),
                       init_mode="standard", use_rslora=use_rslora)
        load_lora(generator, state_dict)
        generator.update_seq_lengths(
            latent_seq_len=seq_cfg.latent_seq_len,
            clip_seq_len=seq_cfg.clip_seq_len,
            sync_seq_len=seq_cfg.sync_seq_len,
        )
        generator.eval()
        print(f"[BigVGAN] LoRA loaded: {Path(lora_adapter_path).name} "
              f"(rank={rank}, {n} layers)", flush=True)

        # Move VAE+vocoder to device for decode
        tod.to(device)

        with torch.no_grad():
            for npz_path in npz_files:
                audio_path = _find_audio_for_npz(npz_path)
                if audio_path is None:
                    print(f" [BigVGAN] No audio for {npz_path.name}, skipping", flush=True)
                    continue

                # Load the clean target first so an unreadable clip is
                # skipped before any encoding work is spent on it.
                try:
                    wav = _load_clean_audio(audio_path, sample_rate, duration)
                except Exception as e:
                    print(f" [BigVGAN] Failed loading {audio_path.name}: {e}", flush=True)
                    continue

                # Read only the needed arrays and close the archive promptly.
                with np.load(str(npz_path), allow_pickle=False) as npz:
                    clip_np = npz["clip_features"]
                    sync_np = npz["sync_features"]
                    stored_prompt = npz["prompt"] if "prompt" in npz.files else default_prompt

                # Pad/trim to expected sequence lengths
                clip_f = _fit_seq_len(torch.from_numpy(clip_np).to(device, dtype),
                                      seq_cfg.clip_seq_len)
                sync_f = _fit_seq_len(torch.from_numpy(sync_np).to(device, dtype),
                                      seq_cfg.sync_seq_len)

                # Text CLIP encoding
                prompt = _as_prompt_text(prompt_map.get(npz_path.name, stored_prompt))
                text_clip = feature_utils.encode_text_clip([prompt]).to(device, dtype)

                # Generate LoRA latent via ODE
                x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
                                 device=device, dtype=dtype, generator=rng)

                # Defaults bind this iteration's conditioning to the closure.
                def velocity_fn(t, x, _cf=clip_f, _sf=sync_f, _tc=text_clip):
                    return generator.forward(x, _cf, _sf, _tc,
                                             t.reshape(1).to(device, dtype))

                x1_pred = fm.to_data(velocity_fn, x0)
                x1_unnorm = generator.unnormalize(x1_pred.clone())

                # VAE decode → mel
                mel = feature_utils.decode(x1_unnorm)  # [1, n_mels, T_mel]

                pairs.append((mel.squeeze(0).float().cpu(), wav.float().cpu()))
                del x0, x1_pred, x1_unnorm, mel
                print(f" [BigVGAN] Generated: {npz_path.stem}", flush=True)

    finally:
        # Always restore the decoder's original placement and free the
        # temporary LoRA generator, even when generation failed midway.
        tod.to(tod_orig_dev)
        del generator
        soft_empty_cache()

    print(f"[BigVGAN] Pre-generated {len(pairs)} LoRA mel / clean audio pairs", flush=True)
    return pairs
+ ), + }), }, } def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size, segment_seconds, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase, - save_every, seed, discriminator_path=""): + save_every, seed, discriminator_path="", lora_adapter=""): import traceback device = get_device() @@ -491,6 +669,14 @@ class SelvaBigvganTrainer: if not disc_path.exists(): raise FileNotFoundError(f"[BigVGAN] Discriminator checkpoint not found: {disc_path}") + lora_path = None + if lora_adapter and lora_adapter.strip(): + lora_path = Path(lora_adapter.strip()) + if not lora_path.is_absolute(): + lora_path = Path(folder_paths.base_path) / lora_path + if not lora_path.exists(): + raise FileNotFoundError(f"[BigVGAN] LoRA adapter not found: {lora_path}") + # Find and pre-load audio clips segment_samples = int(segment_seconds * sample_rate) audio_files = [] @@ -556,6 +742,23 @@ class SelvaBigvganTrainer: def _worker(): try: + # Pre-generate LoRA mels in the worker thread (inference_mode is + # thread-local — off here) so deep-copied generator tensors are clean. + lora_mel_pairs = None + if lora_path is not None: + seq_cfg = model["seq_cfg"] + lora_mel_pairs = _pregenerate_lora_mels( + model, data_dir, str(lora_path), + device, dtype, sample_rate, + seq_cfg.duration, seed=seed, + ) + if not lora_mel_pairs: + raise RuntimeError( + "[BigVGAN] LoRA adapter provided but no mel/audio pairs " + "could be generated. Check that data_dir contains .npz " + "files with matching audio files." 
+ ) + _result[0] = _do_train( vocoder, mel_converter, clips, device, dtype, strategy, feature_utils, @@ -563,6 +766,7 @@ class SelvaBigvganTrainer: train_mode, steps, lr, batch_size, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase, save_every, seed, out_path, disc_path, pbar, + lora_mel_pairs, ) except Exception as e: _exc[0] = e @@ -586,7 +790,8 @@ def _do_train(vocoder, mel_converter, clips, segment_samples, sample_rate, train_mode, steps, lr, batch_size, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase, - save_every, seed, out_path, disc_path, pbar): + save_every, seed, out_path, disc_path, pbar, + lora_mel_pairs=None): """Execute training. Called in a fresh thread — no inference_mode active. Even though inference_mode is off here, tensors created in the calling @@ -761,20 +966,62 @@ def _do_train(vocoder, mel_converter, clips, log_file = open(log_path, "w", buffering=1) # line-buffered log_file.write("step,total_loss,fm_loss,mel_loss,stft_loss,phase_loss,l2sp_loss\n") + # ── Pre-compute mel segment sizes for LoRA mel cropping ─────────────── + # LoRA mels have shape [n_mels, T_mel_full] for the full clip duration. + # We need to crop segment_seconds from both mel and audio at same position. 
+ if lora_mel_pairs: + _example_mel = lora_mel_pairs[0][0] # [n_mels, T_mel_full] + _example_audio = lora_mel_pairs[0][1] # [L] + _mel_frames_full = _example_mel.shape[-1] + _audio_samples_full = _example_audio.shape[0] + # mel frames per audio sample + _mel_per_sample = _mel_frames_full / _audio_samples_full + _mel_segment = int(segment_samples * _mel_per_sample) + print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames " + f"per {segment_samples} audio samples", flush=True) + try: for step in range(steps): - # Sample random batch — clips are CPU floats, move to device - batch = [] - for _ in range(batch_size): - clip = random.choice(clips) - start = random.randint(0, clip.shape[0] - segment_samples) - batch.append(clip[start : start + segment_samples]) + if lora_mel_pairs: + # LoRA mode: sample LoRA mel + matching clean audio from same pair. + # Crop both from the same time position for alignment. + audio_batch = [] + mel_batch = [] + for _ in range(batch_size): + lora_mel, lora_audio = random.choice(lora_mel_pairs) + max_start = lora_audio.shape[0] - segment_samples + if max_start > 0: + audio_start = random.randint(0, max_start) + else: + audio_start = 0 + audio_batch.append(lora_audio[audio_start : audio_start + segment_samples]) + mel_start = int(audio_start * _mel_per_sample) + mel_crop = lora_mel[:, mel_start : mel_start + _mel_segment] + # Pad if crop goes past edge + if mel_crop.shape[-1] < _mel_segment: + mel_crop = F.pad(mel_crop, (0, _mel_segment - mel_crop.shape[-1])) + mel_batch.append(mel_crop) - target_flat = torch.stack(batch).to(device, dtype) # [B, T] - target_wav = target_flat.unsqueeze(1) # [B, 1, T] + target_flat = torch.stack(audio_batch).to(device, dtype) # [B, T] + target_wav = target_flat.unsqueeze(1) # [B, 1, T] + input_mel = torch.stack(mel_batch).to(device, dtype) # [B, n_mels, T_seg] + else: + # Standard mode: sample random crops from clean audio clips + batch = [] + for _ in range(batch_size): + clip = random.choice(clips) + 
start = random.randint(0, clip.shape[0] - segment_samples) + batch.append(clip[start : start + segment_samples]) + target_flat = torch.stack(batch).to(device, dtype) # [B, T] + target_wav = target_flat.unsqueeze(1) # [B, 1, T] + + with torch.no_grad(): + input_mel = mel_converter(target_flat) # [B, n_mels, T_mel] + + # Clean target mel for mel loss (always from clean audio) with torch.no_grad(): - target_mel = mel_converter(target_flat) # [B, n_mels, T_mel] + target_mel = mel_converter(target_flat) # [B, n_mels, T_mel] # Gradient checkpointing: recompute BigVGAN activations during # backward instead of storing them. The 512x upsampling stack @@ -782,7 +1029,7 @@ def _do_train(vocoder, mel_converter, clips, # ~2x compute for a large reduction in activation memory, allowing # batch_size > 1 without OOM. pred_wav = torch.utils.checkpoint.checkpoint( - vocoder, target_mel, use_reentrant=False + vocoder, input_mel, use_reentrant=False ) # [B, 1, T_wav] T = min(pred_wav.shape[-1], target_wav.shape[-1]) @@ -882,4 +1129,22 @@ def _do_train(vocoder, mel_converter, clips, torch.save(save_dict, str(out_path)) print(f"\n[BigVGAN] Saved: {out_path} gafilter={use_gafilter}", flush=True) _save_sample("final") + + # Generate a LoRA mel → vocoder sample so the user can hear the improvement + if lora_mel_pairs: + try: + lora_mel_full = lora_mel_pairs[0][0] # [n_mels, T_mel] + voc_device = next(vocoder.parameters()).device + voc_dtype = next(vocoder.parameters()).dtype + with torch.no_grad(): + wav_lora = vocoder(lora_mel_full.unsqueeze(0).to(voc_device, voc_dtype)) + if wav_lora.dim() == 2: + wav_lora = wav_lora.unsqueeze(1) + wav_lora = wav_lora.float().cpu().clamp(-1, 1) + lora_wav_path = out_path.parent / f"{out_path.stem}_lora_sample.wav" + _save_wav(lora_wav_path, wav_lora.squeeze(0), sample_rate) + print(f"[BigVGAN] LoRA mel sample: {lora_wav_path}", flush=True) + except Exception as e: + print(f"[BigVGAN] LoRA sample failed: {e}", flush=True) + return str(out_path)