fix: cast discriminators to model dtype to match vocoder output
The discriminators (MPD and MRD) are constructed as float32 but receive bfloat16 tensors from the vocoder. Cast them to the model dtype on load to prevent a conv dtype mismatch in the feature-matching loss.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -954,8 +954,8 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
f"MPD loaded={mpd_loaded}, MRD loaded={mrd_loaded}. "
|
f"MPD loaded={mpd_loaded}, MRD loaded={mrd_loaded}. "
|
||||||
f"Available keys: {list(ckpt_d.keys())}"
|
f"Available keys: {list(ckpt_d.keys())}"
|
||||||
)
|
)
|
||||||
mpd.to(device).eval()
|
mpd.to(device, dtype).eval()
|
||||||
mrd.to(device).eval()
|
mrd.to(device, dtype).eval()
|
||||||
for p in mpd.parameters():
|
for p in mpd.parameters():
|
||||||
p.requires_grad_(False)
|
p.requires_grad_(False)
|
||||||
for p in mrd.parameters():
|
for p in mrd.parameters():
|
||||||
|
|||||||
Reference in New Issue
Block a user