fix: cast ref_mean/ref_gram to model dtype before loss computation
ref_mean and ref_gram are float32 (mel computed via cuFFT which requires float32). mel_gen is bfloat16. F.l1_loss(bfloat16, float32) promotes to float32, producing a float32 loss. loss.backward() then pushes float32 gradients through bfloat16 ops → 'Found dtype Float but expected BFloat16'. Fix: clone().detach().to(dtype) at the start of _do_optimize. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -274,9 +274,12 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
normalize, target_lufs, pbar):
|
normalize, target_lufs, pbar):
|
||||||
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
|
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
|
||||||
|
|
||||||
# Strip inference flags from ref stats (came from main thread)
|
# Strip inference flags from ref stats (came from main thread) and cast to
|
||||||
ref_mean = ref_mean.clone().detach()
|
# model dtype. ref_mean/ref_gram are float32 (computed via cuFFT mel path);
|
||||||
ref_gram = ref_gram.clone().detach()
|
# mel_gen is model dtype (bfloat16). Mixed-dtype loss → float32 gradient →
|
||||||
|
# "Found dtype Float but expected BFloat16" in backward through bfloat16 ops.
|
||||||
|
ref_mean = ref_mean.clone().detach().to(dtype)
|
||||||
|
ref_gram = ref_gram.clone().detach().to(dtype)
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user