fix: bypass @torch.inference_mode() on decode to preserve gradient chain
feature_utils.decode and autoencoder.decode are both decorated with @torch.inference_mode(), which disables autograd tracking inside the call. Their outputs therefore carry no grad_fn, and loss.backward() fails with 'does not require grad'.

Fix: call feature_utils.tod.vae.decode() directly; it has no decorator and is fully differentiable. The added transpose ([B, T, C] → [B, C, T]) reproduces what feature_utils.decode did before calling the VAE.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@@ -399,10 +399,13 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
 x = x + dt * flow
 
 # ── Decode to mel (no vocoder — cheap) ──────────────────────────────
-# Direct call — inference flags were stripped from all model weights
-# at the top of _do_optimize, so no checkpoint wrapper is needed.
+# feature_utils.decode and autoencoder.decode are both decorated with
+# @torch.inference_mode(), which destroys the gradient chain.
+# Bypass both wrappers and call vae.decode directly — it has no
+# inference_mode decorator and is fully differentiable.
+# The transpose matches feature_utils.decode: [B, T, C] → [B, C, T].
 x_un = net_generator.unnormalize(x)
-mel_gen = feature_utils.decode(x_un)
+mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2))
 
 # ── Style loss ───────────────────────────────────────────────────────
 loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
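The failure mode described in the commit message can be reproduced in isolation. The sketch below is only an illustration: `TinyDecoder` and `Wrapper` are hypothetical stand-ins for the real VAE and `feature_utils`, whose code is not part of this diff, and the tensor shapes are arbitrary. It assumes the wrapper does nothing beyond a transpose and a call into the undecorated decoder.

```python
import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    """Hypothetical stand-in for the VAE; decode() has no decorator."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Conv1d(4, 8, kernel_size=1)

    def decode(self, z):
        # Expects [B, C, T].
        return self.proj(z)


class Wrapper:
    """Hypothetical stand-in for feature_utils."""

    def __init__(self, vae):
        self.vae = vae

    @torch.inference_mode()
    def decode(self, z):
        # Accepts [B, T, C] and transposes before calling the decoder.
        return self.vae.decode(z.transpose(1, 2))


torch.manual_seed(0)
wrapper = Wrapper(TinyDecoder())
x = torch.randn(2, 5, 4, requires_grad=True)  # [B, T, C]

# Decorated path: the output is computed without autograd tracking.
blocked = wrapper.decode(x)

backward_error = None
try:
    blocked.sum().backward()
except RuntimeError as err:
    backward_error = str(err)

# Direct path: same computation, but the graph back to x is recorded.
direct = wrapper.vae.decode(x.transpose(1, 2))
direct.pow(2).mean().backward()

print("decorated output requires_grad:", blocked.requires_grad)
print("direct output requires_grad:", direct.requires_grad)
print("gradient reached x:", x.grad is not None)
```

Because the decorator wraps the whole call, enabling gradients in the caller does not help. Reaching the undecorated method, as this commit does, is one way to keep the graph intact without editing the wrapper itself.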