diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index 689a7de..0f6ca25 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -274,9 +274,12 @@ def _do_optimize(net_generator, feature_utils, mel_converter, normalize, target_lufs, pbar): """Optimization loop — runs in a fresh thread (no inference_mode active).""" - # Strip inference flags from ref stats (came from main thread) - ref_mean = ref_mean.clone().detach() - ref_gram = ref_gram.clone().detach() + # Strip inference flags from ref stats (came from main thread) and cast to + # model dtype. ref_mean/ref_gram are float32 (computed via cuFFT mel path); + # mel_gen is model dtype (bfloat16). Mixed-dtype loss → float32 gradient → + # "Found dtype Float but expected BFloat16" in backward through bfloat16 ops. + ref_mean = ref_mean.clone().detach().to(dtype) + ref_gram = ref_gram.clone().detach().to(dtype) torch.manual_seed(seed)