732df151b0
ref_mean and ref_gram are float32, because the mel spectrograms are computed
via cuFFT, which requires float32. mel_gen is bfloat16.

F.l1_loss(bfloat16, float32) promotes to float32, producing a float32 loss.
loss.backward() then pushes float32 gradients through bfloat16 ops, which
fails with 'Found dtype Float but expected BFloat16'.

Fix: at the start of _do_optimize, cast ref_mean and ref_gram with
clone().detach().to(dtype) so the loss is computed in the same dtype as
mel_gen.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
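A minimal PyTorch sketch of the cast described above. It assumes that a CPU bfloat16 tensor stands in for mel_gen and a float32 tensor stands in for one of the references. The helper name l1_mel_loss is hypothetical and is not the actual _do_optimize code. The sketch shows that bfloat16 and float32 promote to float32, and that once the reference has been cast, both the loss and the gradient on mel_gen stay bfloat16.

```python
import torch
import torch.nn.functional as F


def l1_mel_loss(mel_gen: torch.Tensor, ref_mean: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (not the real _do_optimize): detach the float32
    # reference and cast it to mel_gen's dtype (e.g. bfloat16) so the loss,
    # and the gradients flowing back into mel_gen, are not promoted to float32.
    ref = ref_mean.clone().detach().to(mel_gen.dtype)
    return F.l1_loss(mel_gen, ref)


# Stand-ins: a bfloat16 "generated" mel and a float32 reference.
gen = torch.randn(4, 80, dtype=torch.bfloat16, requires_grad=True)
ref = torch.randn(4, 80, dtype=torch.float32)

# Mixing the two dtypes directly would promote the result to float32.
assert torch.promote_types(gen.dtype, ref.dtype) == torch.float32

loss = l1_mel_loss(gen, ref)
loss.backward()
print(loss.dtype, gen.grad.dtype)  # torch.bfloat16 torch.bfloat16
```

Casting the reference down rather than casting mel_gen up keeps the graph in the model's dtype end to end. The cost is that the comparison happens at bfloat16 precision.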