fix(bigvgan-trainer): cast audio to model dtype to match bf16 mel_converter buffers
When the model is loaded in bf16, the mel_converter's mel_basis buffer is also bf16. Audio loaded from disk is float32, which causes a dtype mismatch in the matmul. Cast all audio tensors to model["dtype"] before passing them to mel_converter/vocoder.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -121,6 +121,7 @@ class SelvaBigvganTrainer:
|
||||
|
||||
device = get_device()
|
||||
mode = model["mode"]
|
||||
+ dtype = model["dtype"] # bf16/fp16/fp32 — must match mel_converter buffers
|
||||
feature_utils = model["feature_utils"]
|
||||
mel_converter = feature_utils.mel_converter
|
||||
strategy = model["strategy"]
|
||||
@@ -193,8 +194,8 @@ class SelvaBigvganTrainer:
|
||||
random.seed(seed)
|
||||
|
||||
# Fixed reference segment for eval samples — always clip 0, start 0
|
||||
- ref_clip = clips[0][:segment_samples].to(device) # [T]
|
||||
- ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||
+ ref_clip = clips[0][:segment_samples].to(device, dtype) # [T]
|
||||
+ ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||
|
||||
def _save_sample(label):
|
||||
"""Vocode the reference mel and save as .wav."""
|
||||
@@ -231,7 +232,7 @@ class SelvaBigvganTrainer:
|
||||
start = random.randint(0, clip.shape[0] - segment_samples)
|
||||
batch.append(clip[start : start + segment_samples])
|
||||
|
||||
- target_flat = torch.stack(batch).to(device) # [B, T]
|
||||
+ target_flat = torch.stack(batch).to(device, dtype) # [B, T]
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
|
||||
# Fixed target mel (no grad needed here)
|
||||
|
||||
Reference in New Issue
Block a user