From 16e20b30ce4d435d2d0074f6395827258d8f4422 Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Thu, 9 Apr 2026 01:46:01 +0200
Subject: [PATCH] fix(bigvgan-trainer): cast audio to model dtype to match bf16 mel_converter buffers

Model loaded in bf16 causes mel_basis buffer to be bf16. Audio loaded
from disk is float32, causing matmul dtype mismatch. Cast all audio
tensors to model["dtype"] before passing to mel_converter/vocoder.

Co-Authored-By: Claude Sonnet 4.6
---
 nodes/selva_bigvgan_trainer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py
index ba1146f..26cae86 100644
--- a/nodes/selva_bigvgan_trainer.py
+++ b/nodes/selva_bigvgan_trainer.py
@@ -121,6 +121,7 @@ class SelvaBigvganTrainer:
 
         device = get_device()
         mode = model["mode"]
+        dtype = model["dtype"]  # bf16/fp16/fp32 — must match mel_converter buffers
         feature_utils = model["feature_utils"]
         mel_converter = feature_utils.mel_converter
         strategy = model["strategy"]
@@ -193,8 +194,8 @@ class SelvaBigvganTrainer:
         random.seed(seed)
 
         # Fixed reference segment for eval samples — always clip 0, start 0
-        ref_clip = clips[0][:segment_samples].to(device)  # [T]
-        ref_mel = mel_converter(ref_clip.unsqueeze(0))    # [1, n_mels, T_mel]
+        ref_clip = clips[0][:segment_samples].to(device, dtype)  # [T]
+        ref_mel = mel_converter(ref_clip.unsqueeze(0))           # [1, n_mels, T_mel]
 
         def _save_sample(label):
             """Vocode the reference mel and save as .wav."""
@@ -231,7 +232,7 @@ class SelvaBigvganTrainer:
                 start = random.randint(0, clip.shape[0] - segment_samples)
                 batch.append(clip[start : start + segment_samples])
 
-            target_flat = torch.stack(batch).to(device)  # [B, T]
+            target_flat = torch.stack(batch).to(device, dtype)  # [B, T]
             target_wav = target_flat.unsqueeze(1)        # [B, 1, T]
 
             # Fixed target mel (no grad needed here)