fix(bigvgan-trainer): cast audio to model dtype to match bf16 mel_converter buffers
Model loaded in bf16 causes mel_basis buffer to be bf16. Audio loaded from disk is float32, causing matmul dtype mismatch. Cast all audio tensors to model["dtype"] before passing to mel_converter/vocoder. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -121,6 +121,7 @@ class SelvaBigvganTrainer:
|
|||||||
|
|
||||||
device = get_device()
|
device = get_device()
|
||||||
mode = model["mode"]
|
mode = model["mode"]
|
||||||
|
dtype = model["dtype"] # bf16/fp16/fp32 — must match mel_converter buffers
|
||||||
feature_utils = model["feature_utils"]
|
feature_utils = model["feature_utils"]
|
||||||
mel_converter = feature_utils.mel_converter
|
mel_converter = feature_utils.mel_converter
|
||||||
strategy = model["strategy"]
|
strategy = model["strategy"]
|
||||||
@@ -193,8 +194,8 @@ class SelvaBigvganTrainer:
|
|||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
# Fixed reference segment for eval samples — always clip 0, start 0
|
# Fixed reference segment for eval samples — always clip 0, start 0
|
||||||
ref_clip = clips[0][:segment_samples].to(device) # [T]
|
ref_clip = clips[0][:segment_samples].to(device, dtype) # [T]
|
||||||
ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel]
|
ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||||
|
|
||||||
def _save_sample(label):
|
def _save_sample(label):
|
||||||
"""Vocode the reference mel and save as .wav."""
|
"""Vocode the reference mel and save as .wav."""
|
||||||
@@ -231,7 +232,7 @@ class SelvaBigvganTrainer:
|
|||||||
start = random.randint(0, clip.shape[0] - segment_samples)
|
start = random.randint(0, clip.shape[0] - segment_samples)
|
||||||
batch.append(clip[start : start + segment_samples])
|
batch.append(clip[start : start + segment_samples])
|
||||||
|
|
||||||
target_flat = torch.stack(batch).to(device) # [B, T]
|
target_flat = torch.stack(batch).to(device, dtype) # [B, T]
|
||||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||||
|
|
||||||
# Fixed target mel (no grad needed here)
|
# Fixed target mel (no grad needed here)
|
||||||
|
|||||||
Reference in New Issue
Block a user