fix(bigvgan-trainer): cast audio to model dtype to match bf16 mel_converter buffers

When the model is loaded in bf16, the mel_converter's mel_basis buffer is
also bf16. Audio loaded from disk is float32, which causes a matmul dtype
mismatch. Cast all audio tensors to model["dtype"] before passing them to
mel_converter/vocoder.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 01:46:01 +02:00
parent ea7dfed27a
commit 16e20b30ce
+4 -3
View File
@@ -121,6 +121,7 @@ class SelvaBigvganTrainer:
device = get_device() device = get_device()
mode = model["mode"] mode = model["mode"]
dtype = model["dtype"] # bf16/fp16/fp32 — must match mel_converter buffers
feature_utils = model["feature_utils"] feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter mel_converter = feature_utils.mel_converter
strategy = model["strategy"] strategy = model["strategy"]
@@ -193,8 +194,8 @@ class SelvaBigvganTrainer:
random.seed(seed) random.seed(seed)
# Fixed reference segment for eval samples — always clip 0, start 0 # Fixed reference segment for eval samples — always clip 0, start 0
ref_clip = clips[0][:segment_samples].to(device) # [T] ref_clip = clips[0][:segment_samples].to(device, dtype) # [T]
ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel] ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel]
def _save_sample(label): def _save_sample(label):
"""Vocode the reference mel and save as .wav.""" """Vocode the reference mel and save as .wav."""
@@ -231,7 +232,7 @@ class SelvaBigvganTrainer:
start = random.randint(0, clip.shape[0] - segment_samples) start = random.randint(0, clip.shape[0] - segment_samples)
batch.append(clip[start : start + segment_samples]) batch.append(clip[start : start + segment_samples])
target_flat = torch.stack(batch).to(device) # [B, T] target_flat = torch.stack(batch).to(device, dtype) # [B, T]
target_wav = target_flat.unsqueeze(1) # [B, 1, T] target_wav = target_flat.unsqueeze(1) # [B, 1, T]
# Fixed target mel (no grad needed here) # Fixed target mel (no grad needed here)