From db112394e8ad132640fa0ae88a3c7237ffea1c9e Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 16:15:14 +0200 Subject: [PATCH] feat: add AF-Vocoder GAFilter to BigVGAN trainer and loader MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements AF-Vocoder GAFilter (Interspeech 2025): learnable per-channel depthwise FIR filter inserted after each Snake/Activation1d in BigVGAN residual blocks. Initialized as identity so training starts from pretrained behaviour. - inject_gafilters() walks resblocks.*.activations and wraps each Activation1d with _ActivationWithGAFilter — weights appear in vocoder.state_dict() automatically - Trained alongside Snake alphas in snake_alpha_only mode - Checkpoint saves has_gafilter + gafilter_kernel_size metadata - Loader detects metadata and injects before load_state_dict so weights populate correctly - Controlled by use_gafilter (default True) and gafilter_kernel_size (default 9) Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_bigvgan_loader.py | 6 ++ nodes/selva_bigvgan_trainer.py | 109 ++++++++++++++++++++++++++++++--- 2 files changed, 107 insertions(+), 8 deletions(-) diff --git a/nodes/selva_bigvgan_loader.py b/nodes/selva_bigvgan_loader.py index fbff9ed..d6f3ff7 100644 --- a/nodes/selva_bigvgan_loader.py +++ b/nodes/selva_bigvgan_loader.py @@ -13,6 +13,7 @@ import torch import folder_paths from .utils import SELVA_CATEGORY +from .selva_bigvgan_trainer import inject_gafilters class SelvaBigvganLoader: @@ -60,6 +61,11 @@ class SelvaBigvganLoader: else: raise ValueError(f"[BigVGAN] Unknown mode: {mode}") + if ckpt.get("has_gafilter", False): + kernel_size = ckpt.get("gafilter_kernel_size", 9) + n_gaf = inject_gafilters(vocoder, kernel_size) + print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={kernel_size}", flush=True) + vocoder.load_state_dict(ckpt["generator"]) vocoder.eval() diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 
# ---------------------------------------------------------------------------
# AF-Vocoder GAFilter (Interspeech 2025)
# ---------------------------------------------------------------------------

class GAFilter(nn.Module):
    """Learnable per-channel (depthwise) FIR filter applied after an activation.

    The filter is a bias-free ``nn.Conv1d`` with ``groups=channels``, so every
    channel owns an independent ``kernel_size``-tap kernel. The kernel is
    initialised to a unit impulse at its centre tap, which makes the module an
    identity mapping until it is trained.

    Args:
        channels: Number of input (and output) channels.
        kernel_size: Number of FIR taps. Must be a positive odd integer.

    Raises:
        ValueError: If ``kernel_size`` is even or smaller than 1. With
            ``padding=kernel_size // 2`` an even kernel would make the output
            one sample longer than the input, so the filter could no longer
            start out as an identity.
    """

    def __init__(self, channels: int, kernel_size: int = 9):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"GAFilter kernel_size must be a positive odd integer, got {kernel_size}"
            )
        center = kernel_size // 2
        self.conv = nn.Conv1d(
            channels, channels, kernel_size,
            padding=center, groups=channels, bias=False,
        )
        # Unit impulse at the centre tap: output == input before training.
        nn.init.zeros_(self.conv.weight)
        self.conv.weight.data[:, 0, center] = 1.0

    def forward(self, x):
        """Filter ``x`` channel by channel with the learned FIR kernels."""
        return self.conv(x)


class _ActivationWithGAFilter(nn.Module):
    """Run ``activation`` and then ``gafilter`` on its output.

    Both modules are registered as submodules, so their parameters show up in
    the parent's ``state_dict()`` under ``activation.*`` and ``gafilter.*``.
    """

    def __init__(self, activation: nn.Module, gafilter: GAFilter):
        super().__init__()
        self.activation = activation
        self.gafilter = gafilter

    def forward(self, x):
        return self.gafilter(self.activation(x))


def inject_gafilters(vocoder: nn.Module, kernel_size: int = 9) -> int:
    """Wrap each eligible activation in ``vocoder.resblocks`` with a GAFilter.

    An entry ``resblock.activations[j]`` is eligible when it exposes
    ``.act.alpha``; the length of ``alpha``'s first dimension is used as the
    channel count. Eligible entries are replaced in place by
    ``_ActivationWithGAFilter``, so the filter weights appear in
    ``vocoder.state_dict()`` under
    ``resblocks.{i}.activations.{j}.gafilter.conv.weight`` and a normal
    ``load_state_dict`` call after injection populates them.

    Each filter is created on the same device and dtype as the ``alpha`` it
    follows, so the injected modules match the rest of the vocoder without
    requiring a later ``vocoder.to(...)``.

    Entries that were already wrapped have no ``act`` attribute and are
    skipped, so calling this function twice does not stack filters.

    Args:
        vocoder: Module to modify in place. Objects without ``resblocks`` are
            left untouched.
        kernel_size: FIR length passed to every ``GAFilter``.

    Returns:
        The number of filters injected.

    Raises:
        ValueError: Propagated from ``GAFilter`` when ``kernel_size`` is not a
            positive odd integer (only if at least one activation is eligible).
    """
    count = 0
    for resblock in getattr(vocoder, "resblocks", []):
        activations = getattr(resblock, "activations", None)
        if activations is None:
            continue
        # Iterate over a snapshot: entries are reassigned inside the loop.
        for j, act1d in enumerate(list(activations)):
            alpha = getattr(getattr(act1d, "act", None), "alpha", None)
            if alpha is None:
                continue
            gafilter = GAFilter(alpha.shape[0], kernel_size).to(
                device=alpha.device, dtype=alpha.dtype,
            )
            activations[j] = _ActivationWithGAFilter(act1d, gafilter)
            count += 1
    return count
Larger = wider frequency response control.", + }), "lambda_phase": ("FLOAT", { "default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1, "tooltip": ( @@ -367,7 +441,8 @@ class SelvaBigvganTrainer: } def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size, - segment_seconds, lambda_l2sp, lambda_phase, save_every, seed, discriminator_path=""): + segment_seconds, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase, + save_every, seed, discriminator_path=""): import traceback device = get_device() @@ -475,7 +550,8 @@ class SelvaBigvganTrainer: vocoder, mel_converter, clips, device, dtype, strategy, feature_utils, segment_samples, sample_rate, - train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase, + train_mode, steps, lr, batch_size, lambda_l2sp, + use_gafilter, gafilter_kernel_size, lambda_phase, save_every, seed, out_path, disc_path, pbar, ) except Exception as e: @@ -498,7 +574,8 @@ class SelvaBigvganTrainer: def _do_train(vocoder, mel_converter, clips, device, dtype, strategy, feature_utils, segment_samples, sample_rate, - train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase, + train_mode, steps, lr, batch_size, lambda_l2sp, + use_gafilter, gafilter_kernel_size, lambda_phase, save_every, seed, out_path, disc_path, pbar): """Execute training. Called in a fresh thread — no inference_mode active. @@ -592,18 +669,25 @@ def _do_train(vocoder, mel_converter, clips, if buf is not None: module._buffers[bname] = buf.clone() + # ── GAFilter injection (after inference-flag stripping) ────────────────── + # GAFilter params are fresh tensors — no inference flag to strip. 
+ if use_gafilter: + n_gaf = inject_gafilters(vocoder, gafilter_kernel_size) + vocoder.to(device) + print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={gafilter_kernel_size}", flush=True) + # ── Training mode: select which parameters to train ────────────────────── if train_mode == "snake_alpha_only": alpha_params = [] for name, param in vocoder.named_parameters(): - if "alpha" in name: + if "alpha" in name or (use_gafilter and "gafilter" in name): param.requires_grad_(True) alpha_params.append(param) else: param.requires_grad_(False) n_trainable = sum(p.numel() for p in alpha_params) print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params " - f"({len(alpha_params)} alpha tensors)", flush=True) + f"({len(alpha_params)} tensors, gafilter={'yes' if use_gafilter else 'no'})", flush=True) trainable_params = alpha_params else: # all_params for param in vocoder.parameters(): @@ -749,7 +833,11 @@ def _do_train(vocoder, mel_converter, clips, if (step + 1) % save_every == 0 and (step + 1) < steps: step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}" - torch.save({"generator": vocoder.state_dict()}, str(step_path)) + torch.save({ + "generator": vocoder.state_dict(), + "has_gafilter": use_gafilter, + "gafilter_kernel_size": gafilter_kernel_size if use_gafilter else 9, + }, str(step_path)) print(f"[BigVGAN] Checkpoint: {step_path}", flush=True) vocoder.eval() _save_sample(f"step{step+1}") @@ -762,7 +850,12 @@ def _do_train(vocoder, mel_converter, clips, feature_utils.to("cpu") soft_empty_cache() - torch.save({"generator": vocoder.state_dict()}, str(out_path)) - print(f"\n[BigVGAN] Saved: {out_path}", flush=True) + save_dict = { + "generator": vocoder.state_dict(), + "has_gafilter": use_gafilter, + "gafilter_kernel_size": gafilter_kernel_size if use_gafilter else 9, + } + torch.save(save_dict, str(out_path)) + print(f"\n[BigVGAN] Saved: {out_path} gafilter={use_gafilter}", flush=True) _save_sample("final") return str(out_path)