feat: add AF-Vocoder GAFilter to BigVGAN trainer and loader
Implements AF-Vocoder GAFilter (Interspeech 2025): learnable per-channel depthwise FIR filter inserted after each Snake/Activation1d in BigVGAN residual blocks. Initialized as identity so training starts from pretrained behaviour. - inject_gafilters() walks resblocks.*.activations and wraps each Activation1d with _ActivationWithGAFilter — weights appear in vocoder.state_dict() automatically - Trained alongside Snake alphas in snake_alpha_only mode - Checkpoint saves has_gafilter + gafilter_kernel_size metadata - Loader detects metadata and injects before load_state_dict so weights populate correctly - Controlled by use_gafilter (default True) and gafilter_kernel_size (default 9) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -161,6 +161,68 @@ def _feature_matching_loss(fmaps_real, fmaps_gen):
|
||||
return loss / len(fmaps_real)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AF-Vocoder GAFilter (Interspeech 2025)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GAFilter(nn.Module):
|
||||
"""Learnable per-channel depthwise FIR filter inserted after Snake activations.
|
||||
|
||||
Initialized as identity (delta at center) so training starts from the
|
||||
pretrained vocoder's behaviour. Learns to shape the per-channel frequency
|
||||
response to fix harmonic artifacts.
|
||||
"""
|
||||
def __init__(self, channels: int, kernel_size: int = 9):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
channels, channels, kernel_size,
|
||||
padding=kernel_size // 2, groups=channels, bias=False,
|
||||
)
|
||||
nn.init.zeros_(self.conv.weight)
|
||||
self.conv.weight.data[:, 0, kernel_size // 2] = 1.0 # identity
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class _ActivationWithGAFilter(nn.Module):
|
||||
def __init__(self, activation: nn.Module, gafilter: GAFilter):
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
self.gafilter = gafilter
|
||||
|
||||
def forward(self, x):
|
||||
return self.gafilter(self.activation(x))
|
||||
|
||||
|
||||
def inject_gafilters(vocoder: nn.Module, kernel_size: int = 9) -> int:
|
||||
"""Inject GAFilter after each Activation1d in BigVGAN residual blocks.
|
||||
|
||||
Modifies vocoder in-place. GAFilter weights appear in vocoder.state_dict()
|
||||
under resblocks.{i}.activations.{j}.gafilter.conv.weight — so a normal
|
||||
load_state_dict call after injection will populate them correctly.
|
||||
|
||||
Returns the number of injected filters.
|
||||
"""
|
||||
count = 0
|
||||
for resblock in getattr(vocoder, "resblocks", []):
|
||||
activations = getattr(resblock, "activations", None)
|
||||
if activations is None:
|
||||
continue
|
||||
for j in range(len(activations)):
|
||||
act1d = activations[j]
|
||||
act = getattr(act1d, "act", None)
|
||||
if act is None:
|
||||
continue
|
||||
alpha = getattr(act, "alpha", None)
|
||||
if alpha is None:
|
||||
continue
|
||||
channels = alpha.shape[0]
|
||||
activations[j] = _ActivationWithGAFilter(act1d, GAFilter(channels, kernel_size))
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -342,6 +404,18 @@ class SelvaBigvganTrainer:
|
||||
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
|
||||
),
|
||||
}),
|
||||
"use_gafilter": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": (
|
||||
"Inject AF-Vocoder GAFilter (Interspeech 2025) after each Snake activation. "
|
||||
"Adds a learnable depthwise FIR filter per channel, initialized as identity. "
|
||||
"Trained alongside Snake alphas. Saved into the checkpoint for inference."
|
||||
),
|
||||
}),
|
||||
"gafilter_kernel_size": ("INT", {
|
||||
"default": 9, "min": 3, "max": 31, "step": 2,
|
||||
"tooltip": "FIR filter length for GAFilter. Must be odd. Larger = wider frequency response control.",
|
||||
}),
|
||||
"lambda_phase": ("FLOAT", {
|
||||
"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1,
|
||||
"tooltip": (
|
||||
@@ -367,7 +441,8 @@ class SelvaBigvganTrainer:
|
||||
}
|
||||
|
||||
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
|
||||
segment_seconds, lambda_l2sp, lambda_phase, save_every, seed, discriminator_path=""):
|
||||
segment_seconds, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||
save_every, seed, discriminator_path=""):
|
||||
import traceback
|
||||
|
||||
device = get_device()
|
||||
@@ -475,7 +550,8 @@ class SelvaBigvganTrainer:
|
||||
vocoder, mel_converter, clips,
|
||||
device, dtype, strategy, feature_utils,
|
||||
segment_samples, sample_rate,
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase,
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||
use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||
save_every, seed, out_path, disc_path, pbar,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -498,7 +574,8 @@ class SelvaBigvganTrainer:
|
||||
def _do_train(vocoder, mel_converter, clips,
|
||||
device, dtype, strategy, feature_utils,
|
||||
segment_samples, sample_rate,
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase,
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||
use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||
save_every, seed, out_path, disc_path, pbar):
|
||||
"""Execute training. Called in a fresh thread — no inference_mode active.
|
||||
|
||||
@@ -592,18 +669,25 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
if buf is not None:
|
||||
module._buffers[bname] = buf.clone()
|
||||
|
||||
# ── GAFilter injection (after inference-flag stripping) ──────────────────
|
||||
# GAFilter params are fresh tensors — no inference flag to strip.
|
||||
if use_gafilter:
|
||||
n_gaf = inject_gafilters(vocoder, gafilter_kernel_size)
|
||||
vocoder.to(device)
|
||||
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={gafilter_kernel_size}", flush=True)
|
||||
|
||||
# ── Training mode: select which parameters to train ──────────────────────
|
||||
if train_mode == "snake_alpha_only":
|
||||
alpha_params = []
|
||||
for name, param in vocoder.named_parameters():
|
||||
if "alpha" in name:
|
||||
if "alpha" in name or (use_gafilter and "gafilter" in name):
|
||||
param.requires_grad_(True)
|
||||
alpha_params.append(param)
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
n_trainable = sum(p.numel() for p in alpha_params)
|
||||
print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params "
|
||||
f"({len(alpha_params)} alpha tensors)", flush=True)
|
||||
f"({len(alpha_params)} tensors, gafilter={'yes' if use_gafilter else 'no'})", flush=True)
|
||||
trainable_params = alpha_params
|
||||
else: # all_params
|
||||
for param in vocoder.parameters():
|
||||
@@ -749,7 +833,11 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
|
||||
if (step + 1) % save_every == 0 and (step + 1) < steps:
|
||||
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
|
||||
torch.save({"generator": vocoder.state_dict()}, str(step_path))
|
||||
torch.save({
|
||||
"generator": vocoder.state_dict(),
|
||||
"has_gafilter": use_gafilter,
|
||||
"gafilter_kernel_size": gafilter_kernel_size if use_gafilter else 9,
|
||||
}, str(step_path))
|
||||
print(f"[BigVGAN] Checkpoint: {step_path}", flush=True)
|
||||
vocoder.eval()
|
||||
_save_sample(f"step{step+1}")
|
||||
@@ -762,7 +850,12 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
feature_utils.to("cpu")
|
||||
soft_empty_cache()
|
||||
|
||||
torch.save({"generator": vocoder.state_dict()}, str(out_path))
|
||||
print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
|
||||
save_dict = {
|
||||
"generator": vocoder.state_dict(),
|
||||
"has_gafilter": use_gafilter,
|
||||
"gafilter_kernel_size": gafilter_kernel_size if use_gafilter else 9,
|
||||
}
|
||||
torch.save(save_dict, str(out_path))
|
||||
print(f"\n[BigVGAN] Saved: {out_path} gafilter={use_gafilter}", flush=True)
|
||||
_save_sample("final")
|
||||
return str(out_path)
|
||||
|
||||
Reference in New Issue
Block a user