fix: cast discriminators to model dtype to match vocoder output

Discriminators are constructed as float32 but receive bfloat16 tensors
from the vocoder. Cast to model dtype on load to prevent conv dtype
mismatch in feature matching loss.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 00:25:04 +02:00
parent 187b2e3169
commit 0854bd2638
+2 -2
View File
@@ -954,8 +954,8 @@ def _do_train(vocoder, mel_converter, clips,
f"MPD loaded={mpd_loaded}, MRD loaded={mrd_loaded}. " f"MPD loaded={mpd_loaded}, MRD loaded={mrd_loaded}. "
f"Available keys: {list(ckpt_d.keys())}" f"Available keys: {list(ckpt_d.keys())}"
) )
mpd.to(device).eval() mpd.to(device, dtype).eval()
mrd.to(device).eval() mrd.to(device, dtype).eval()
for p in mpd.parameters(): for p in mpd.parameters():
p.requires_grad_(False) p.requires_grad_(False)
for p in mrd.parameters(): for p in mrd.parameters():