From 817b75df49874d83c0c57e844f2c43dc1751c576 Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Thu, 9 Apr 2026 17:44:35 +0200
Subject: [PATCH] fix: bypass @torch.inference_mode() on decode to preserve gradient chain
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

feature_utils.decode and autoencoder.decode are both decorated with
@torch.inference_mode(), which unconditionally destroys grad_fn on all
outputs — making loss.backward() fail with 'does not require grad'.

Fix: call feature_utils.tod.vae.decode() directly, which has no
decorator and is fully differentiable. Transpose matches the original
wrapper signature.

Co-Authored-By: Claude Sonnet 4.6
---
 nodes/selva_ditto_optimizer.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py
index 1a93399..689a7de 100644
--- a/nodes/selva_ditto_optimizer.py
+++ b/nodes/selva_ditto_optimizer.py
@@ -399,10 +399,13 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
             x = x + dt * flow
 
         # ── Decode to mel (no vocoder — cheap) ──────────────────────────────
-        # Direct call — inference flags were stripped from all model weights
-        # at the top of _do_optimize, so no checkpoint wrapper is needed.
+        # feature_utils.decode and autoencoder.decode are both decorated with
+        # @torch.inference_mode(), which destroys the gradient chain.
+        # Bypass both wrappers and call vae.decode directly — it has no
+        # inference_mode decorator and is fully differentiable.
+        # The transpose matches feature_utils.decode: [B, T, C] → [B, C, T].
         x_un = net_generator.unnormalize(x)
-        mel_gen = feature_utils.decode(x_un)
+        mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2))
 
         # ── Style loss ───────────────────────────────────────────────────────
         loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)