# Reconstructed from a whitespace-collapsed patch that touches three files.
# Each section below is headed by the path it belongs to.

# =============================================================================
# nodes/__init__.py
#
# Entries added to the existing ``_NODES`` registry (the dict header and the
# registration loop live outside this view and are unchanged):
#
#     "SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"),
#     "SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"),
# =============================================================================


# =============================================================================
# nodes/selva_bigvgan_loader.py
# =============================================================================
"""SelVA BigVGAN Loader.

Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer
and replaces the vocoder weights in the loaded SELVA_MODEL in-place.

The model is modified in-place so ComfyUI's model cache is updated — no need to
reload the full SelVA model. Subsequent Sampler runs will use the fine-tuned vocoder.
"""

from pathlib import Path

import torch
import folder_paths

from .utils import SELVA_CATEGORY


class SelvaBigvganLoader:
    """ComfyUI node that injects fine-tuned BigVGAN weights into a SELVA_MODEL."""

    CATEGORY = SELVA_CATEGORY
    FUNCTION = "load"
    RETURN_TYPES = ("SELVA_MODEL",)
    RETURN_NAMES = ("model",)
    OUTPUT_TOOLTIPS = ("SELVA_MODEL with the fine-tuned BigVGAN vocoder injected.",)
    DESCRIPTION = (
        "Loads a fine-tuned BigVGAN vocoder checkpoint from SelVA BigVGAN Trainer "
        "and replaces the vocoder weights in the SELVA_MODEL. "
        "Connect the output to SelVA Sampler instead of the base model loader."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: the model to patch and the checkpoint path."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "path": ("STRING", {
                    "default": "bigvgan_bj.pt",
                    "tooltip": "Path to fine-tuned vocoder checkpoint (.pt). "
                               "Relative paths resolve to ComfyUI output directory.",
                }),
            },
        }

    def load(self, model, path):
        """Load ``{'generator': state_dict}`` from *path* into the model's vocoder.

        Args:
            model: SELVA_MODEL mapping; ``model["mode"]`` and
                ``model["feature_utils"]`` are read.
            path: Checkpoint path. Relative paths are resolved against the
                ComfyUI output directory.

        Returns:
            A 1-tuple holding the same ``model`` object, with its vocoder
            weights replaced in place.

        Raises:
            FileNotFoundError: The resolved path is not an existing file.
            NotImplementedError: ``model["mode"]`` is not ``"16k"``.
            ValueError: The checkpoint has no ``"generator"`` entry.
        """
        p = Path(path.strip())
        if not p.is_absolute():
            p = Path(folder_paths.get_output_directory()) / p
        # is_file() rather than exists(): a directory (e.g. an empty path that
        # resolves to the output dir itself) must not reach torch.load.
        if not p.is_file():
            raise FileNotFoundError(f"[BigVGAN] Checkpoint not found: {p}")

        if model["mode"] != "16k":
            raise NotImplementedError(
                "[BigVGAN] Fine-tuned loader only supports 16k mode."
            )

        # The checkpoint is a plain state dict, so the restricted unpickler is
        # sufficient; it refuses arbitrary pickled objects from a user-supplied file.
        ckpt = torch.load(str(p), map_location="cpu", weights_only=True)
        if not isinstance(ckpt, dict) or "generator" not in ckpt:
            found = list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt).__name__
            raise ValueError(
                f"[BigVGAN] Expected {{'generator': ...}} in checkpoint, got keys: {found}"
            )

        vocoder = model["feature_utils"].tod.vocoder.vocoder
        vocoder.load_state_dict(ckpt["generator"])
        vocoder.eval()

        print(f"[BigVGAN] Loaded fine-tuned vocoder from: {p}", flush=True)
        return (model,)


# =============================================================================
# nodes/selva_bigvgan_trainer.py
# =============================================================================
"""SelVA BigVGAN Vocoder Fine-tuner.

Fine-tunes only the BigVGAN vocoder (mel → waveform) on BJ audio clips using
spectral reconstruction losses. The DiT and VAE are completely untouched.

Loss: L1 mel reconstruction + multi-resolution STFT magnitude L1.
No GAN discriminator — this is a proof-of-concept to verify that the vocoder
can absorb BJ timbral characteristics before investing in full adversarial training.

Save format: {'generator': vocoder.state_dict()} — same as the original BigVGAN
checkpoint so it can be loaded with SelVA BigVGAN Loader.
"""

import random
import traceback
from pathlib import Path

import torch
import torch.nn.functional as F
import torchaudio
import comfy.utils
import folder_paths

from .utils import SELVA_CATEGORY, get_device, soft_empty_cache

# Multi-resolution STFT windows — same three resolutions as BigVGAN discriminator config.
# Each entry is (n_fft, hop_length, win_length).
_STFT_RESOLUTIONS = [
    (1024, 120, 600),
    (2048, 240, 1200),
    (512, 50, 240),
]

# File suffixes (lower-case) accepted as training audio.
_AUDIO_SUFFIXES = frozenset({".wav", ".flac", ".mp3", ".ogg", ".aac"})


def _stft_mag(wav, n_fft, hop_length, win_length, device):
    """Magnitude STFT. wav: [B, T] → [B, n_fft//2+1, T']"""
    window = torch.hann_window(win_length, device=device)
    spec = torch.stft(
        wav, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
        window=window, center=True, return_complex=True,
    )
    return spec.abs()


def _multi_resolution_stft_loss(pred_wav, target_wav, device):
    """Average L1 mag loss across three STFT resolutions. inputs: [B, 1, T]"""
    pred = pred_wav.squeeze(1)  # [B, T]
    target = target_wav.squeeze(1)
    loss = torch.zeros(1, device=device)
    for n_fft, hop, win in _STFT_RESOLUTIONS:
        pm = _stft_mag(pred, n_fft, hop, win, device)
        tm = _stft_mag(target, n_fft, hop, win, device)
        # Guard against a one-frame mismatch between the two spectrograms.
        T = min(pm.shape[-1], tm.shape[-1])
        loss = loss + F.l1_loss(pm[..., :T], tm[..., :T])
    return loss / len(_STFT_RESOLUTIONS)


def _resolve_data_dir(data_dir):
    """Turn the user-supplied data_dir string into an existing absolute directory."""
    text = data_dir.strip()
    if not text:
        # An empty string would otherwise resolve to the models directory
        # itself and recursively scan everything under it.
        raise ValueError("[BigVGAN] data_dir is empty — point it at a directory of audio files.")
    resolved = Path(text)
    if not resolved.is_absolute():
        resolved = Path(folder_paths.models_dir) / resolved
    if not resolved.exists():
        raise FileNotFoundError(f"[BigVGAN] data_dir not found: {resolved}")
    return resolved


def _find_audio_files(data_dir):
    """Recursively list audio files under data_dir, sorted so seeded runs are reproducible."""
    return sorted(
        candidate
        for candidate in data_dir.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in _AUDIO_SUFFIXES
    )


def _load_clips(audio_files, sample_rate, segment_samples, segment_seconds):
    """Load files as mono 1-D tensors at sample_rate, skipping short or unreadable ones."""
    clips = []
    for af in audio_files:
        # Deliberately best-effort: one bad file must not abort the whole run.
        try:
            wav, sr = torchaudio.load(str(af))
            if wav.shape[0] > 1:
                wav = wav.mean(0, keepdim=True)  # down-mix to mono
            if sr != sample_rate:
                wav = torchaudio.functional.resample(wav, sr, sample_rate)
            wav = wav.squeeze(0)  # [L]
            if wav.shape[0] >= segment_samples:
                clips.append(wav)
            else:
                print(f" [BigVGAN] Skip {af.name}: shorter than {segment_seconds}s", flush=True)
        except Exception as e:
            print(f" [BigVGAN] Failed {af.name}: {e}", flush=True)
            traceback.print_exc()
    return clips


def _sample_batch(rng, clips, batch_size, segment_samples):
    """Draw batch_size random fixed-length segments from clips. Returns [B, T]."""
    segments = []
    for _ in range(batch_size):
        clip = rng.choice(clips)
        start = rng.randint(0, clip.shape[0] - segment_samples)
        segments.append(clip[start:start + segment_samples])
    return torch.stack(segments)


class SelvaBigvganTrainer:
    """ComfyUI node that fine-tunes the 16k BigVGAN vocoder with spectral losses."""

    OUTPUT_NODE = True
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "train"
    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("checkpoint_path",)
    OUTPUT_TOOLTIPS = ("Path to saved vocoder checkpoint — load with SelVA BigVGAN Loader.",)
    DESCRIPTION = (
        "Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips using "
        "spectral losses (mel L1 + multi-resolution STFT L1). DiT and VAE stay frozen. "
        "16k mode only. Load the result with SelVA BigVGAN Loader."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (model, data/output paths and hyper-parameters)."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "data_dir": ("STRING", {
                    "default": "",
                    "tooltip": "Directory with BJ audio files (.wav/.flac/.mp3/.ogg/.aac). "
                               "Searched recursively. Relative paths → ComfyUI models dir.",
                }),
                "output_path": ("STRING", {
                    "default": "bigvgan_bj.pt",
                    "tooltip": "Where to save the fine-tuned vocoder. Relative paths → ComfyUI output dir.",
                }),
                "steps": ("INT", {
                    "default": 2000, "min": 100, "max": 50000,
                    "tooltip": "Training steps. 1000–2000 is a good first experiment.",
                }),
                "lr": ("FLOAT", {
                    "default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-5,
                    "tooltip": "Learning rate. BigVGAN default is 1e-4.",
                }),
                "batch_size": ("INT", {"default": 4, "min": 1, "max": 32}),
                "segment_seconds": ("FLOAT", {
                    "default": 1.0, "min": 0.25, "max": 4.0, "step": 0.25,
                    "tooltip": "Audio segment length per training sample in seconds.",
                }),
                "save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
                "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    def train(self, model, data_dir, output_path, steps, lr, batch_size,
              segment_seconds, save_every, seed):
        """Fine-tune the model's vocoder in place and save it to ``output_path``.

        The vocoder held by ``model`` is the object being optimised, so the
        passed model carries the updated weights after this call.

        Args:
            model: SELVA_MODEL mapping; ``mode``, ``feature_utils`` and
                ``strategy`` are read.
            data_dir: Directory of audio files, searched recursively. Relative
                paths resolve against ``folder_paths.models_dir``.
            output_path: Checkpoint destination. Relative paths resolve against
                the ComfyUI output directory.
            steps: Number of optimiser steps.
            lr: AdamW learning rate.
            batch_size: Segments per step.
            segment_seconds: Length of each training segment in seconds.
            save_every: Write an intermediate ``*_step{N}`` checkpoint every N steps.
            seed: Seed for segment sampling and for torch's global RNG.

        Returns:
            A 1-tuple with the final checkpoint path as a string.

        Raises:
            NotImplementedError: ``model["mode"]`` is not ``"16k"``.
            ValueError: ``data_dir`` is empty.
            FileNotFoundError: ``data_dir`` is missing or contains no audio files.
            RuntimeError: No clip is at least ``segment_seconds`` long.
        """
        if model["mode"] != "16k":
            raise NotImplementedError(
                "[BigVGAN] Only 16k mode is supported. "
                "44k uses BigVGANv2 which requires a different training setup."
            )

        device = get_device()
        feature_utils = model["feature_utils"]
        mel_converter = feature_utils.mel_converter
        sample_rate = 16_000
        strategy = model["strategy"]

        # BigVGANVocoder nn.Module — bypass the @inference_mode wrapper on BigVGAN.forward
        vocoder = feature_utils.tod.vocoder.vocoder

        # Resolve paths
        data_root = _resolve_data_dir(data_dir)

        out_path = Path(output_path.strip())
        if not out_path.is_absolute():
            out_path = Path(folder_paths.get_output_directory()) / out_path
        out_path.parent.mkdir(parents=True, exist_ok=True)

        # Find and pre-load audio clips
        segment_samples = int(segment_seconds * sample_rate)
        audio_files = _find_audio_files(data_root)
        if not audio_files:
            raise FileNotFoundError(f"[BigVGAN] No audio files found in {data_root}")

        print(f"[BigVGAN] Loading {len(audio_files)} audio files...", flush=True)
        clips = _load_clips(audio_files, sample_rate, segment_samples, segment_seconds)
        if not clips:
            raise RuntimeError(
                f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)"
            )
        print(f"[BigVGAN] {len(clips)} clips ready segment={segment_seconds}s "
              f"steps={steps} lr={lr} batch={batch_size}\n", flush=True)

        torch.manual_seed(seed)
        # Private generator: same draw sequence as seeding the global `random`
        # module, without disturbing other code that relies on it.
        rng = random.Random(seed)

        pbar = comfy.utils.ProgressBar(steps)
        log_every = max(1, steps // 20)

        # Everything that mutates the shared model sits inside the try so the
        # finally block always restores eval mode / frozen grads / CPU offload,
        # even when setup itself fails (e.g. out of memory during the device move).
        try:
            # Setup also runs with inference mode disabled so that tensors
            # created while moving modules are ordinary autograd-capable tensors.
            with torch.inference_mode(False), torch.enable_grad():
                if strategy == "offload_to_cpu":
                    feature_utils.to(device)
                    soft_empty_cache()

                mel_converter.to(device)
                vocoder.requires_grad_(True)
                optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
                vocoder.train()

                for step in range(1, steps + 1):
                    # Sample random batch
                    target_flat = _sample_batch(
                        rng, clips, batch_size, segment_samples
                    ).to(device)                                   # [B, T]
                    target_wav = target_flat.unsqueeze(1)          # [B, 1, T]

                    # Fixed target mel (no grad needed here)
                    with torch.no_grad():
                        target_mel = mel_converter(target_flat)    # [B, 80, T_mel]

                    # Vocoder forward: mel → waveform
                    pred_wav = vocoder(target_mel)                 # [B, 1, T_wav]

                    # Align lengths
                    T = min(pred_wav.shape[-1], target_wav.shape[-1])
                    pred_t = pred_wav[..., :T]
                    target_t = target_wav[..., :T]

                    # Mel reconstruction loss: mel(pred) vs target_mel
                    pred_mel = mel_converter(pred_t.squeeze(1))    # [B, 80, T_mel']
                    T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
                    mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])

                    # Multi-resolution STFT loss
                    stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)

                    loss = mel_loss + stft_loss
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0)
                    optimizer.step()

                    pbar.update(1)

                    if step % log_every == 0 or step == steps:
                        print(f"[BigVGAN] {step}/{steps} "
                              f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} "
                              f"total={loss.item():.4f}", flush=True)

                    # Intermediate checkpoints only; the final one is written below.
                    if step % save_every == 0 and step < steps:
                        step_path = out_path.parent / f"{out_path.stem}_step{step}{out_path.suffix}"
                        torch.save({"generator": vocoder.state_dict()}, str(step_path))
                        print(f"[BigVGAN] Checkpoint: {step_path}", flush=True)

        finally:
            vocoder.requires_grad_(False)
            vocoder.eval()
            if strategy == "offload_to_cpu":
                feature_utils.to("cpu")
                soft_empty_cache()

        torch.save({"generator": vocoder.state_dict()}, str(out_path))
        print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
        return (str(out_path),)