fix: cast GAFilter to model dtype after injection

GAFilter conv weights are created as float32 but the rest of the vocoder
is bfloat16. vocoder.to(device) missed the dtype cast, causing conv1d
dtype mismatch when Snake bfloat16 output flows into GAFilter.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 00:24:11 +02:00
parent 608746ce7b
commit 187b2e3169
+1 -1
View File
@@ -895,7 +895,7 @@ def _do_train(vocoder, mel_converter, clips,
# GAFilter params are fresh tensors — no inference flag to strip.
if use_gafilter:
n_gaf = inject_gafilters(vocoder, gafilter_kernel_size)
vocoder.to(device)
vocoder.to(device, dtype)
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={gafilter_kernel_size}", flush=True)
# ── Training mode: select which parameters to train ──────────────────────