fix: cast ref_mean/ref_gram to model dtype before loss computation

ref_mean and ref_gram are float32 (mel computed via cuFFT which requires
float32). mel_gen is bfloat16. F.l1_loss(bfloat16, float32) promotes to
float32, producing a float32 loss. loss.backward() then pushes float32
gradients through bfloat16 ops → 'Found dtype Float but expected BFloat16'.

Fix: clone().detach().to(dtype) at the start of _do_optimize.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 17:48:41 +02:00
parent 817b75df49
commit 732df151b0
+6 -3
View File
@@ -274,9 +274,12 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
normalize, target_lufs, pbar):
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
# Strip inference flags from ref stats (came from main thread)
ref_mean = ref_mean.clone().detach()
ref_gram = ref_gram.clone().detach()
# Strip inference flags from ref stats (came from main thread) and cast to
# model dtype. ref_mean/ref_gram are float32 (computed via cuFFT mel path);
# mel_gen is model dtype (bfloat16). Mixed-dtype loss → float32 gradient →
# "Found dtype Float but expected BFloat16" in backward through bfloat16 ops.
ref_mean = ref_mean.clone().detach().to(dtype)
ref_gram = ref_gram.clone().detach().to(dtype)
torch.manual_seed(seed)