feat: add AF-Vocoder GAFilter to BigVGAN trainer and loader
Implements the AF-Vocoder GAFilter (Interspeech 2025): a learnable per-channel depthwise FIR filter inserted after each Snake/Activation1d in BigVGAN residual blocks. Initialized as identity so training starts from pretrained behaviour.

- inject_gafilters() walks resblocks.*.activations and wraps each Activation1d with _ActivationWithGAFilter — weights appear in vocoder.state_dict() automatically
- Trained alongside Snake alphas in snake_alpha_only mode
- Checkpoint saves has_gafilter + gafilter_kernel_size metadata
- Loader detects the metadata and injects before load_state_dict so the weights populate correctly
- Controlled by use_gafilter (default True) and gafilter_kernel_size (default 9)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,6 +13,7 @@ import torch
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
from .utils import SELVA_CATEGORY
|
||||||
|
from .selva_bigvgan_trainer import inject_gafilters
|
||||||
|
|
||||||
|
|
||||||
class SelvaBigvganLoader:
|
class SelvaBigvganLoader:
|
||||||
@@ -60,6 +61,11 @@ class SelvaBigvganLoader:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
|
raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
|
||||||
|
|
||||||
|
if ckpt.get("has_gafilter", False):
|
||||||
|
kernel_size = ckpt.get("gafilter_kernel_size", 9)
|
||||||
|
n_gaf = inject_gafilters(vocoder, kernel_size)
|
||||||
|
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={kernel_size}", flush=True)
|
||||||
|
|
||||||
vocoder.load_state_dict(ckpt["generator"])
|
vocoder.load_state_dict(ckpt["generator"])
|
||||||
vocoder.eval()
|
vocoder.eval()
|
||||||
|
|
||||||
|
|||||||
@@ -161,6 +161,68 @@ def _feature_matching_loss(fmaps_real, fmaps_gen):
|
|||||||
return loss / len(fmaps_real)
|
return loss / len(fmaps_real)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AF-Vocoder GAFilter (Interspeech 2025)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class GAFilter(nn.Module):
    """Learnable per-channel depthwise FIR filter inserted after Snake activations.

    The filter is a bias-free depthwise ``nn.Conv1d`` (``groups=channels``), so
    every channel owns an independent ``kernel_size``-tap FIR kernel. Each
    kernel is initialized to a unit impulse at its centre tap, which makes the
    module an identity mapping at construction time: training starts from the
    pretrained vocoder's behaviour and learns to shape the per-channel
    frequency response to fix harmonic artifacts.

    Args:
        channels: Number of input (and output) channels.
        kernel_size: Number of FIR taps per channel. Must be a positive odd
            integer so that ``padding=kernel_size // 2`` keeps the sequence
            length unchanged and the kernel has a well-defined centre tap.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, channels: int, kernel_size: int = 9):
        # With an even kernel, padding=k//2 yields one extra output sample and
        # there is no centre tap, so the "identity at init" guarantee would be
        # silently broken. Fail loudly instead.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"GAFilter kernel_size must be a positive odd integer, got {kernel_size}"
            )
        super().__init__()
        center = kernel_size // 2
        self.conv = nn.Conv1d(
            channels, channels, kernel_size,
            padding=center, groups=channels, bias=False,
        )
        # Depthwise weight has shape (channels, 1, kernel_size): zero every tap,
        # then place a 1 at the centre tap of each channel (delta == identity).
        nn.init.zeros_(self.conv.weight)
        with torch.no_grad():
            self.conv.weight[:, 0, center] = 1.0

    def forward(self, x):
        """Apply the per-channel FIR filter to ``x`` and return the result."""
        return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
class _ActivationWithGAFilter(nn.Module):
    """Chain an activation module with a GAFilter applied to its output.

    Both sub-modules are registered as attributes (``activation`` and
    ``gafilter``), so their parameters are reachable through the parent
    module's parameter and state-dict traversal.

    Args:
        activation: Module called first on the input.
        gafilter: Filter module called on the activation's output.
    """

    def __init__(self, activation: nn.Module, gafilter: GAFilter):
        super().__init__()
        # Registration order is kept as activation -> gafilter.
        self.activation = activation
        self.gafilter = gafilter

    def forward(self, x):
        """Return ``gafilter(activation(x))``."""
        activated = self.activation(x)
        filtered = self.gafilter(activated)
        return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def inject_gafilters(vocoder: nn.Module, kernel_size: int = 9) -> int:
    """Wrap eligible residual-block activations of ``vocoder`` with a GAFilter.

    The vocoder is modified in place. For every entry of
    ``vocoder.resblocks[i].activations`` that exposes an ``act`` attribute
    which in turn exposes ``alpha``, the entry is replaced by an
    ``_ActivationWithGAFilter`` holding the original entry and a new
    ``GAFilter`` whose channel count is ``alpha.shape[0]``. Entries (or whole
    blocks) missing any of those attributes are left untouched.

    After injection the filter weights live under
    ``resblocks.{i}.activations.{j}.gafilter.conv.weight`` in
    ``vocoder.state_dict()``, so a normal ``load_state_dict`` call made after
    injection will populate them.

    Args:
        vocoder: Model whose ``resblocks`` are scanned; a model without that
            attribute results in no changes.
        kernel_size: FIR length forwarded to each new ``GAFilter``.

    Returns:
        The number of filters that were injected.
    """
    injected = 0
    for block in getattr(vocoder, "resblocks", []):
        slots = getattr(block, "activations", None)
        if slots is None:
            continue
        # Iterate over a snapshot so replacing slots[idx] below cannot
        # interfere with the traversal.
        for idx, wrapped in enumerate(tuple(slots)):
            inner = getattr(wrapped, "act", None)
            alpha = None if inner is None else getattr(inner, "alpha", None)
            if alpha is None:
                # Not a Snake-style activation (no per-channel alpha): skip.
                continue
            n_channels = alpha.shape[0]
            slots[idx] = _ActivationWithGAFilter(
                wrapped, GAFilter(n_channels, kernel_size)
            )
            injected += 1
    return injected
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Utility helpers
|
# Utility helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -342,6 +404,18 @@ class SelvaBigvganTrainer:
|
|||||||
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
|
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
|
||||||
),
|
),
|
||||||
}),
|
}),
|
||||||
|
"use_gafilter": ("BOOLEAN", {
|
||||||
|
"default": True,
|
||||||
|
"tooltip": (
|
||||||
|
"Inject AF-Vocoder GAFilter (Interspeech 2025) after each Snake activation. "
|
||||||
|
"Adds a learnable depthwise FIR filter per channel, initialized as identity. "
|
||||||
|
"Trained alongside Snake alphas. Saved into the checkpoint for inference."
|
||||||
|
),
|
||||||
|
}),
|
||||||
|
"gafilter_kernel_size": ("INT", {
|
||||||
|
"default": 9, "min": 3, "max": 31, "step": 2,
|
||||||
|
"tooltip": "FIR filter length for GAFilter. Must be odd. Larger = wider frequency response control.",
|
||||||
|
}),
|
||||||
"lambda_phase": ("FLOAT", {
|
"lambda_phase": ("FLOAT", {
|
||||||
"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1,
|
"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1,
|
||||||
"tooltip": (
|
"tooltip": (
|
||||||
@@ -367,7 +441,8 @@ class SelvaBigvganTrainer:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
|
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
|
||||||
segment_seconds, lambda_l2sp, lambda_phase, save_every, seed, discriminator_path=""):
|
segment_seconds, lambda_l2sp, use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||||
|
save_every, seed, discriminator_path=""):
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
device = get_device()
|
device = get_device()
|
||||||
@@ -475,7 +550,8 @@ class SelvaBigvganTrainer:
|
|||||||
vocoder, mel_converter, clips,
|
vocoder, mel_converter, clips,
|
||||||
device, dtype, strategy, feature_utils,
|
device, dtype, strategy, feature_utils,
|
||||||
segment_samples, sample_rate,
|
segment_samples, sample_rate,
|
||||||
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase,
|
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||||
|
use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||||
save_every, seed, out_path, disc_path, pbar,
|
save_every, seed, out_path, disc_path, pbar,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -498,7 +574,8 @@ class SelvaBigvganTrainer:
|
|||||||
def _do_train(vocoder, mel_converter, clips,
|
def _do_train(vocoder, mel_converter, clips,
|
||||||
device, dtype, strategy, feature_utils,
|
device, dtype, strategy, feature_utils,
|
||||||
segment_samples, sample_rate,
|
segment_samples, sample_rate,
|
||||||
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase,
|
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||||
|
use_gafilter, gafilter_kernel_size, lambda_phase,
|
||||||
save_every, seed, out_path, disc_path, pbar):
|
save_every, seed, out_path, disc_path, pbar):
|
||||||
"""Execute training. Called in a fresh thread — no inference_mode active.
|
"""Execute training. Called in a fresh thread — no inference_mode active.
|
||||||
|
|
||||||
@@ -592,18 +669,25 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
if buf is not None:
|
if buf is not None:
|
||||||
module._buffers[bname] = buf.clone()
|
module._buffers[bname] = buf.clone()
|
||||||
|
|
||||||
|
# ── GAFilter injection (after inference-flag stripping) ──────────────────
|
||||||
|
# GAFilter params are fresh tensors — no inference flag to strip.
|
||||||
|
if use_gafilter:
|
||||||
|
n_gaf = inject_gafilters(vocoder, gafilter_kernel_size)
|
||||||
|
vocoder.to(device)
|
||||||
|
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={gafilter_kernel_size}", flush=True)
|
||||||
|
|
||||||
# ── Training mode: select which parameters to train ──────────────────────
|
# ── Training mode: select which parameters to train ──────────────────────
|
||||||
if train_mode == "snake_alpha_only":
|
if train_mode == "snake_alpha_only":
|
||||||
alpha_params = []
|
alpha_params = []
|
||||||
for name, param in vocoder.named_parameters():
|
for name, param in vocoder.named_parameters():
|
||||||
if "alpha" in name:
|
if "alpha" in name or (use_gafilter and "gafilter" in name):
|
||||||
param.requires_grad_(True)
|
param.requires_grad_(True)
|
||||||
alpha_params.append(param)
|
alpha_params.append(param)
|
||||||
else:
|
else:
|
||||||
param.requires_grad_(False)
|
param.requires_grad_(False)
|
||||||
n_trainable = sum(p.numel() for p in alpha_params)
|
n_trainable = sum(p.numel() for p in alpha_params)
|
||||||
print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params "
|
print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params "
|
||||||
f"({len(alpha_params)} alpha tensors)", flush=True)
|
f"({len(alpha_params)} tensors, gafilter={'yes' if use_gafilter else 'no'})", flush=True)
|
||||||
trainable_params = alpha_params
|
trainable_params = alpha_params
|
||||||
else: # all_params
|
else: # all_params
|
||||||
for param in vocoder.parameters():
|
for param in vocoder.parameters():
|
||||||
@@ -749,7 +833,11 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
|
|
||||||
if (step + 1) % save_every == 0 and (step + 1) < steps:
|
if (step + 1) % save_every == 0 and (step + 1) < steps:
|
||||||
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
|
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
|
||||||
torch.save({"generator": vocoder.state_dict()}, str(step_path))
|
torch.save({
|
||||||
|
"generator": vocoder.state_dict(),
|
||||||
|
"has_gafilter": use_gafilter,
|
||||||
|
"gafilter_kernel_size": gafilter_kernel_size if use_gafilter else 9,
|
||||||
|
}, str(step_path))
|
||||||
print(f"[BigVGAN] Checkpoint: {step_path}", flush=True)
|
print(f"[BigVGAN] Checkpoint: {step_path}", flush=True)
|
||||||
vocoder.eval()
|
vocoder.eval()
|
||||||
_save_sample(f"step{step+1}")
|
_save_sample(f"step{step+1}")
|
||||||
@@ -762,7 +850,12 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
feature_utils.to("cpu")
|
feature_utils.to("cpu")
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|
||||||
torch.save({"generator": vocoder.state_dict()}, str(out_path))
|
save_dict = {
|
||||||
print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
|
"generator": vocoder.state_dict(),
|
||||||
|
"has_gafilter": use_gafilter,
|
||||||
|
"gafilter_kernel_size": gafilter_kernel_size if use_gafilter else 9,
|
||||||
|
}
|
||||||
|
torch.save(save_dict, str(out_path))
|
||||||
|
print(f"\n[BigVGAN] Saved: {out_path} gafilter={use_gafilter}", flush=True)
|
||||||
_save_sample("final")
|
_save_sample("final")
|
||||||
return str(out_path)
|
return str(out_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user