"""SelVA BigVGAN Vocoder Fine-tuner. Fine-tunes only the BigVGAN vocoder (mel → waveform) on BJ audio clips using spectral reconstruction losses. The DiT and VAE are completely untouched. Loss: L1 mel reconstruction + multi-resolution STFT magnitude L1. No GAN discriminator — this is a proof-of-concept to verify that the vocoder can absorb BJ timbral characteristics before investing in full adversarial training. Save format: {'generator': vocoder.state_dict()} — same as the original BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader. """ import random import threading from pathlib import Path import torch import torch.nn.functional as F import torchaudio import comfy.utils import folder_paths from .utils import SELVA_CATEGORY, get_device, soft_empty_cache def _load_wav(path): """Load audio file to [channels, samples] float32 tensor. Tries torchaudio first; falls back to soundfile for wav/flac when the ffmpeg/torchcodec backend is unavailable (e.g. libavutil soname mismatch). """ try: return torchaudio.load(str(path)) except Exception: pass # soundfile fallback — handles wav, flac, ogg natively without ffmpeg import soundfile as sf data, sr = sf.read(str(path), dtype="float32", always_2d=True) wav = torch.from_numpy(data.T) # [channels, samples] return wav, sr # Multi-resolution STFT windows — same three resolutions as BigVGAN discriminator config. _STFT_RESOLUTIONS = [ (1024, 120, 600), (2048, 240, 1200), (512, 50, 240), ] def _stft_mag(wav, n_fft, hop_length, win_length, device): """Magnitude STFT. wav: [B, T] → [B, n_fft//2+1, T']""" window = torch.hann_window(win_length, device=device) spec = torch.stft( wav, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, center=True, return_complex=True, ) return spec.abs() def _multi_resolution_stft_loss(pred_wav, target_wav, device): """Average L1 mag loss across three STFT resolutions. inputs: [B, 1, T]""" pred = pred_wav.squeeze(1) # [B, T] target = target_wav.squeeze(1) loss = torch.zeros(1, device=device) for n_fft, hop, win in _STFT_RESOLUTIONS: pm = _stft_mag(pred, n_fft, hop, win, device) tm = _stft_mag(target, n_fft, hop, win, device) T = min(pm.shape[-1], tm.shape[-1]) loss = loss + F.l1_loss(pm[..., :T], tm[..., :T]) return loss / len(_STFT_RESOLUTIONS) class SelvaBigvganTrainer: OUTPUT_NODE = True CATEGORY = SELVA_CATEGORY FUNCTION = "train" RETURN_TYPES = ("STRING",) RETURN_NAMES = ("checkpoint_path",) OUTPUT_TOOLTIPS = ("Path to saved vocoder checkpoint — load with SelVA BigVGAN Loader.",) DESCRIPTION = ( "Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips using " "spectral losses (mel L1 + multi-resolution STFT L1). DiT and VAE stay frozen. " "Supports both 16k (BigVGAN) and 44k (BigVGANv2) models. " "Load the result with SelVA BigVGAN Loader." ) @classmethod def INPUT_TYPES(cls): return { "required": { "model": ("SELVA_MODEL",), "data_dir": ("STRING", { "default": "", "tooltip": "Directory with BJ audio files (.wav/.flac/.mp3). Searched recursively.", }), "output_path": ("STRING", { "default": "bigvgan_bj.pt", "tooltip": "Where to save the fine-tuned vocoder. Relative paths → ComfyUI output dir.", }), "steps": ("INT", { "default": 2000, "min": 100, "max": 50000, "tooltip": "Training steps. 1000–2000 is a good first experiment.", }), "lr": ("FLOAT", { "default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-5, "tooltip": "Learning rate. BigVGAN default is 1e-4.", }), "batch_size": ("INT", {"default": 4, "min": 1, "max": 32}), "segment_seconds": ("FLOAT", { "default": 1.0, "min": 0.25, "max": 4.0, "step": 0.25, "tooltip": "Audio segment length per training sample in seconds.", }), "save_every": ("INT", {"default": 500, "min": 50, "max": 10000}), "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}), }, } def train(self, model, data_dir, output_path, steps, lr, batch_size, segment_seconds, save_every, seed): import traceback device = get_device() mode = model["mode"] dtype = model["dtype"] feature_utils = model["feature_utils"] mel_converter = feature_utils.mel_converter strategy = model["strategy"] if mode == "16k": vocoder = feature_utils.tod.vocoder.vocoder sample_rate = 16_000 elif mode == "44k": vocoder = feature_utils.tod.vocoder sample_rate = 44_100 else: raise ValueError(f"[BigVGAN] Unknown mode: {mode}") # Resolve paths data_dir = Path(data_dir.strip()) if not data_dir.is_absolute(): data_dir = Path(folder_paths.models_dir) / data_dir if not data_dir.exists(): raise FileNotFoundError(f"[BigVGAN] data_dir not found: {data_dir}") out_path = Path(output_path.strip()) if not out_path.is_absolute(): out_path = Path(folder_paths.get_output_directory()) / out_path out_path.parent.mkdir(parents=True, exist_ok=True) # Find and pre-load audio clips segment_samples = int(segment_seconds * sample_rate) audio_files = [] for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg", "*.aac"): audio_files.extend(data_dir.rglob(ext)) if not audio_files: raise FileNotFoundError(f"[BigVGAN] No audio files found in {data_dir}") print(f"[BigVGAN] Loading {len(audio_files)} audio files...", flush=True) clips = [] for af in audio_files: try: wav, sr = _load_wav(af) if wav.shape[0] > 1: wav = wav.mean(0, keepdim=True) if sr != sample_rate: wav = torchaudio.functional.resample(wav, sr, sample_rate) wav = wav.squeeze(0) # [L] if wav.shape[0] >= segment_samples: clips.append(wav.cpu()) else: print(f" [BigVGAN] Skip {af.name}: shorter than {segment_seconds}s", flush=True) except Exception as e: print(f" [BigVGAN] Failed {af.name}: {e}", flush=True) traceback.print_exc() if not clips: raise RuntimeError( f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)" ) print(f"[BigVGAN] {len(clips)} clips ready segment={segment_seconds}s " f"steps={steps} lr={lr} batch={batch_size}\n", flush=True) if strategy == "offload_to_cpu": feature_utils.to(device) soft_empty_cache() mel_converter.to(device) pbar = comfy.utils.ProgressBar(steps) # ----------------------------------------------------------------------- # Run the entire training in a fresh thread. # # ComfyUI executes nodes inside torch.inference_mode(). Even with an inner # inference_mode(False) context, factory functions and operations may still # produce inference tensors in some environments (e.g. when the outer # context is set via an async wrapper or a third-party hook). # # torch.inference_mode is THREAD-LOCAL. A new thread always starts with # inference_mode disabled, so all tensor operations in the worker thread # produce normal, autograd-compatible tensors — no flags to fight. # ----------------------------------------------------------------------- _result = [None] _exc = [None] def _worker(): try: _result[0] = _do_train( vocoder, mel_converter, clips, device, dtype, strategy, feature_utils, segment_samples, sample_rate, steps, lr, batch_size, save_every, seed, out_path, pbar, ) except Exception as e: _exc[0] = e traceback.print_exc() t = threading.Thread(target=_worker, daemon=True) t.start() t.join() if _exc[0] is not None: raise _exc[0] return (_result[0],) def _do_train(vocoder, mel_converter, clips, device, dtype, strategy, feature_utils, segment_samples, sample_rate, steps, lr, batch_size, save_every, seed, out_path, pbar): """Execute training. Called in a fresh thread — no inference_mode active.""" import torch.nn as nn_mod torch.manual_seed(seed) random.seed(seed) # Reference segment for eval samples — always clip 0, start 0 ref_wav = clips[0][:segment_samples].to(device, dtype) # [T] ref_mel = mel_converter(ref_wav.unsqueeze(0)) # [1, n_mels, T_mel] def _save_sample(label): try: voc_device = next(vocoder.parameters()).device mel = ref_mel.to(voc_device) with torch.no_grad(): wav = vocoder(mel) if wav.dim() == 2: wav = wav.unsqueeze(1) wav = wav.float().cpu().clamp(-1, 1) wav_path = out_path.parent / f"{out_path.stem}_{label}.wav" torchaudio.save(str(wav_path), wav.squeeze(0), sample_rate) print(f"[BigVGAN] Sample saved: {wav_path}", flush=True) except Exception as e: print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True) _save_sample("baseline") # Replace inference-tensor parameters with fresh trainable parameters. # (Vocoder was loaded by ComfyUI before this thread existed; its parameters # may carry the inference flag from that context. Here we're in a clean thread # so all new tensors are normal.) for module in vocoder.modules(): for pname, param in list(module._parameters.items()): if param is not None: fresh = param.data.clone().detach().requires_grad_(True) module._parameters[pname] = nn_mod.Parameter(fresh) optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99)) vocoder.train() try: for step in range(steps): # Sample random batch — clips are CPU floats, move to device batch = [] for _ in range(batch_size): clip = random.choice(clips) start = random.randint(0, clip.shape[0] - segment_samples) batch.append(clip[start : start + segment_samples]) target_flat = torch.stack(batch).to(device, dtype) # [B, T] target_wav = target_flat.unsqueeze(1) # [B, 1, T] with torch.no_grad(): target_mel = mel_converter(target_flat) # [B, n_mels, T_mel] pred_wav = vocoder(target_mel) # [B, 1, T_wav] T = min(pred_wav.shape[-1], target_wav.shape[-1]) pred_t = pred_wav[..., :T] target_t = target_wav[..., :T] pred_mel = mel_converter(pred_t.squeeze(1)) # [B, n_mels, T_mel'] T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel]) stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device) loss = mel_loss + stft_loss optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0) optimizer.step() pbar.update(1) if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1: print(f"[BigVGAN] {step+1}/{steps} " f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} " f"total={loss.item():.4f}", flush=True) if (step + 1) % save_every == 0 and (step + 1) < steps: step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}" torch.save({"generator": vocoder.state_dict()}, str(step_path)) print(f"[BigVGAN] Checkpoint: {step_path}", flush=True) vocoder.eval() _save_sample(f"step{step+1}") vocoder.train() finally: vocoder.requires_grad_(False) vocoder.eval() if strategy == "offload_to_cpu": feature_utils.to("cpu") soft_empty_cache() torch.save({"generator": vocoder.state_dict()}, str(out_path)) print(f"\n[BigVGAN] Saved: {out_path}", flush=True) _save_sample("final") return str(out_path)