fix: match mel dtype to vocoder in baseline sample generation
ref_mel is float32 (from mel_converter) but vocoder weights are bfloat16 before inference flag stripping. Cast mel to vocoder's dtype to prevent input/bias type mismatch during baseline sample save. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -915,8 +915,8 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
|
||||
def _save_sample(label):
|
||||
try:
|
||||
voc_device = next(vocoder.parameters()).device
|
||||
mel = ref_mel.to(voc_device)
|
||||
voc_p = next(vocoder.parameters())
|
||||
mel = ref_mel.to(voc_p.device, voc_p.dtype)
|
||||
with torch.no_grad():
|
||||
wav = vocoder(mel)
|
||||
if wav.dim() == 2:
|
||||
|
||||
Reference in New Issue
Block a user