fix: match mel dtype to vocoder in baseline sample generation
ref_mel is float32 (from mel_converter) but vocoder weights are bfloat16 before inference flag stripping. Cast mel to vocoder's dtype to prevent input/bias type mismatch during baseline sample save. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -915,8 +915,8 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
|
|
||||||
def _save_sample(label):
|
def _save_sample(label):
|
||||||
try:
|
try:
|
||||||
voc_device = next(vocoder.parameters()).device
|
voc_p = next(vocoder.parameters())
|
||||||
mel = ref_mel.to(voc_device)
|
mel = ref_mel.to(voc_p.device, voc_p.dtype)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
wav = vocoder(mel)
|
wav = vocoder(mel)
|
||||||
if wav.dim() == 2:
|
if wav.dim() == 2:
|
||||||
|
|||||||
Reference in New Issue
Block a user