fix: cast GAFilter to model dtype after injection
GAFilter conv weights are created as float32 but the rest of the vocoder is bfloat16. vocoder.to(device) missed the dtype cast, causing conv1d dtype mismatch when Snake bfloat16 output flows into GAFilter. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -895,7 +895,7 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
# GAFilter params are fresh tensors — no inference flag to strip.
|
||||
if use_gafilter:
|
||||
n_gaf = inject_gafilters(vocoder, gafilter_kernel_size)
|
||||
vocoder.to(device)
|
||||
vocoder.to(device, dtype)
|
||||
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={gafilter_kernel_size}", flush=True)
|
||||
|
||||
# ── Training mode: select which parameters to train ──────────────────────
|
||||
|
||||
Reference in New Issue
Block a user