feat: add AF-Vocoder GAFilter to BigVGAN trainer and loader

Implements AF-Vocoder GAFilter (Interspeech 2025): learnable per-channel
depthwise FIR filter inserted after each Snake/Activation1d in BigVGAN
residual blocks. Initialized as identity so training starts from pretrained
behaviour.

- inject_gafilters() walks resblocks.*.activations and wraps each Activation1d
  with _ActivationWithGAFilter — weights appear in vocoder.state_dict() automatically
- Trained alongside Snake alphas in snake_alpha_only mode
- Checkpoint saves has_gafilter + gafilter_kernel_size metadata
- Loader detects metadata and injects before load_state_dict so weights populate correctly
- Controlled by use_gafilter (default True) and gafilter_kernel_size (default 9)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 16:15:14 +02:00
parent c53ea5517c
commit db112394e8
2 changed files with 107 additions and 8 deletions
+6
View File
@@ -13,6 +13,7 @@ import torch
import folder_paths import folder_paths
from .utils import SELVA_CATEGORY from .utils import SELVA_CATEGORY
from .selva_bigvgan_trainer import inject_gafilters
class SelvaBigvganLoader: class SelvaBigvganLoader:
@@ -60,6 +61,11 @@ class SelvaBigvganLoader:
else: else:
raise ValueError(f"[BigVGAN] Unknown mode: {mode}") raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
if ckpt.get("has_gafilter", False):
kernel_size = ckpt.get("gafilter_kernel_size", 9)
n_gaf = inject_gafilters(vocoder, kernel_size)
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={kernel_size}", flush=True)
vocoder.load_state_dict(ckpt["generator"]) vocoder.load_state_dict(ckpt["generator"])
vocoder.eval() vocoder.eval()
+101 -8
View File
@@ -161,6 +161,68 @@ def _feature_matching_loss(fmaps_real, fmaps_gen):
return loss / len(fmaps_real) return loss / len(fmaps_real)
# ---------------------------------------------------------------------------
# AF-Vocoder GAFilter (Interspeech 2025)
# ---------------------------------------------------------------------------
class GAFilter(nn.Module):
"""Learnable per-channel depthwise FIR filter inserted after Snake activations.
Initialized as identity (delta at center) so training starts from the
pretrained vocoder's behaviour. Learns to shape the per-channel frequency
response to fix harmonic artifacts.
"""
def __init__(self, channels: int, kernel_size: int = 9):
super().__init__()
self.conv = nn.Conv1d(
channels, channels, kernel_size,
padding=kernel_size // 2, groups=channels, bias=False,
)
nn.init.zeros_(self.conv.weight)
self.conv.weight.data[:, 0, kernel_size // 2] = 1.0 # identity
def forward(self, x):
return self.conv(x)
class _ActivationWithGAFilter(nn.Module):
def __init__(self, activation: nn.Module, gafilter: GAFilter):
super().__init__()
self.activation = activation
self.gafilter = gafilter
def forward(self, x):
return self.gafilter(self.activation(x))
def inject_gafilters(vocoder: nn.Module, kernel_size: int = 9) -> int:
"""Inject GAFilter after each Activation1d in BigVGAN residual blocks.
Modifies vocoder in-place. GAFilter weights appear in vocoder.state_dict()
under resblocks.{i}.activations.{j}.gafilter.conv.weight — so a normal
load_state_dict call after injection will populate them correctly.
Returns the number of injected filters.
"""
count = 0
for resblock in getattr(vocoder, "resblocks", []):
activations = getattr(resblock, "activations", None)
if activations is None:
continue
for j in range(len(activations)):
act1d = activations[j]
act = getattr(act1d, "act", None)
if act is None:
continue
alpha = getattr(act, "alpha", None)
if alpha is None:
continue
channels = alpha.shape[0]
activations[j] = _ActivationWithGAFilter(act1d, GAFilter(channels, kernel_size))
count += 1
return count
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Utility helpers # Utility helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -342,6 +404,18 @@ class SelvaBigvganTrainer:
"Increase to 1e-2 for all_params to prevent catastrophic forgetting." "Increase to 1e-2 for all_params to prevent catastrophic forgetting."
), ),
}), }),
"use_gafilter": ("BOOLEAN", {
"default": True,
"tooltip": (
"Inject AF-Vocoder GAFilter (Interspeech 2025) after each Snake activation. "
"Adds a learnable depthwise FIR filter per channel, initialized as identity. "
"Trained alongside Snake alphas. Saved into the checkpoint for inference."
),
}),
"gafilter_kernel_size": ("INT", {
"default": 9, "min": 3, "max": 31, "step": 2,
"tooltip": "FIR filter length for GAFilter. Must be odd. Larger = wider frequency response control.",
}),
"lambda_phase": ("FLOAT", { "lambda_phase": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1, "default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1,
"tooltip": ( "tooltip": (
@@ -367,7 +441,8 @@ class SelvaBigvganTrainer:
} }
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size, def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
segment_seconds, lambda_l2sp, lambda_phase, save_every, seed, discriminator_path=""): segment_seconds, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase,
save_every, seed, discriminator_path=""):
import traceback import traceback
device = get_device() device = get_device()
@@ -475,7 +550,8 @@ class SelvaBigvganTrainer:
vocoder, mel_converter, clips, vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils, device, dtype, strategy, feature_utils,
segment_samples, sample_rate, segment_samples, sample_rate,
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase, train_mode, steps, lr, batch_size, lambda_l2sp,
use_gafilter, gafilter_kernel_size, lambda_phase,
save_every, seed, out_path, disc_path, pbar, save_every, seed, out_path, disc_path, pbar,
) )
except Exception as e: except Exception as e:
@@ -498,7 +574,8 @@ class SelvaBigvganTrainer:
def _do_train(vocoder, mel_converter, clips, def _do_train(vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils, device, dtype, strategy, feature_utils,
segment_samples, sample_rate, segment_samples, sample_rate,
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase, train_mode, steps, lr, batch_size, lambda_l2sp,
use_gafilter, gafilter_kernel_size, lambda_phase,
save_every, seed, out_path, disc_path, pbar): save_every, seed, out_path, disc_path, pbar):
"""Execute training. Called in a fresh thread — no inference_mode active. """Execute training. Called in a fresh thread — no inference_mode active.
@@ -592,18 +669,25 @@ def _do_train(vocoder, mel_converter, clips,
if buf is not None: if buf is not None:
module._buffers[bname] = buf.clone() module._buffers[bname] = buf.clone()
# ── GAFilter injection (after inference-flag stripping) ──────────────────
# GAFilter params are fresh tensors — no inference flag to strip.
if use_gafilter:
n_gaf = inject_gafilters(vocoder, gafilter_kernel_size)
vocoder.to(device)
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={gafilter_kernel_size}", flush=True)
# ── Training mode: select which parameters to train ────────────────────── # ── Training mode: select which parameters to train ──────────────────────
if train_mode == "snake_alpha_only": if train_mode == "snake_alpha_only":
alpha_params = [] alpha_params = []
for name, param in vocoder.named_parameters(): for name, param in vocoder.named_parameters():
if "alpha" in name: if "alpha" in name or (use_gafilter and "gafilter" in name):
param.requires_grad_(True) param.requires_grad_(True)
alpha_params.append(param) alpha_params.append(param)
else: else:
param.requires_grad_(False) param.requires_grad_(False)
n_trainable = sum(p.numel() for p in alpha_params) n_trainable = sum(p.numel() for p in alpha_params)
print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params " print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params "
f"({len(alpha_params)} alpha tensors)", flush=True) f"({len(alpha_params)} tensors, gafilter={'yes' if use_gafilter else 'no'})", flush=True)
trainable_params = alpha_params trainable_params = alpha_params
else: # all_params else: # all_params
for param in vocoder.parameters(): for param in vocoder.parameters():
@@ -749,7 +833,11 @@ def _do_train(vocoder, mel_converter, clips,
if (step + 1) % save_every == 0 and (step + 1) < steps: if (step + 1) % save_every == 0 and (step + 1) < steps:
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}" step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
torch.save({"generator": vocoder.state_dict()}, str(step_path)) torch.save({
"generator": vocoder.state_dict(),
"has_gafilter": use_gafilter,
"gafilter_kernel_size": gafilter_kernel_size if use_gafilter else 9,
}, str(step_path))
print(f"[BigVGAN] Checkpoint: {step_path}", flush=True) print(f"[BigVGAN] Checkpoint: {step_path}", flush=True)
vocoder.eval() vocoder.eval()
_save_sample(f"step{step+1}") _save_sample(f"step{step+1}")
@@ -762,7 +850,12 @@ def _do_train(vocoder, mel_converter, clips,
feature_utils.to("cpu") feature_utils.to("cpu")
soft_empty_cache() soft_empty_cache()
torch.save({"generator": vocoder.state_dict()}, str(out_path)) save_dict = {
print(f"\n[BigVGAN] Saved: {out_path}", flush=True) "generator": vocoder.state_dict(),
"has_gafilter": use_gafilter,
"gafilter_kernel_size": gafilter_kernel_size if use_gafilter else 9,
}
torch.save(save_dict, str(out_path))
print(f"\n[BigVGAN] Saved: {out_path} gafilter={use_gafilter}", flush=True)
_save_sample("final") _save_sample("final")
return str(out_path) return str(out_path)