From 732df151b04d924521a2f7b7a8d63a7df70bb45a Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Thu, 9 Apr 2026 17:48:41 +0200
Subject: [PATCH] fix: cast ref_mean/ref_gram to model dtype before loss
 computation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

ref_mean and ref_gram are float32 (mel computed via cuFFT which requires
float32). mel_gen is bfloat16. F.l1_loss(bfloat16, float32) promotes to
float32, producing a float32 loss. loss.backward() then pushes float32
gradients through bfloat16 ops → 'Found dtype Float but expected BFloat16'.

Fix: clone().detach().to(dtype) at the start of _do_optimize.

Co-Authored-By: Claude Sonnet 4.6
---
 nodes/selva_ditto_optimizer.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py
index 689a7de..0f6ca25 100644
--- a/nodes/selva_ditto_optimizer.py
+++ b/nodes/selva_ditto_optimizer.py
@@ -274,9 +274,12 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
                  normalize, target_lufs, pbar):
     """Optimization loop — runs in a fresh thread (no inference_mode active)."""
 
-    # Strip inference flags from ref stats (came from main thread)
-    ref_mean = ref_mean.clone().detach()
-    ref_gram = ref_gram.clone().detach()
+    # Strip inference flags from ref stats (came from main thread) and cast to
+    # model dtype. ref_mean/ref_gram are float32 (computed via cuFFT mel path);
+    # mel_gen is model dtype (bfloat16). Mixed-dtype loss → float32 gradient →
+    # "Found dtype Float but expected BFloat16" in backward through bfloat16 ops.
+    ref_mean = ref_mean.clone().detach().to(dtype)
+    ref_gram = ref_gram.clone().detach().to(dtype)
 
     torch.manual_seed(seed)
 