diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index 1a93399..689a7de 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -399,10 +399,13 @@ def _do_optimize(net_generator, feature_utils, mel_converter, x = x + dt * flow # ── Decode to mel (no vocoder — cheap) ────────────────────────────── - # Direct call — inference flags were stripped from all model weights - # at the top of _do_optimize, so no checkpoint wrapper is needed. + # feature_utils.decode and autoencoder.decode are both decorated with + # @torch.inference_mode(), which destroys the gradient chain. + # Bypass both wrappers and call vae.decode directly — it has no + # inference_mode decorator and is fully differentiable. + # The transpose matches feature_utils.decode: [B, T, C] → [B, C, T]. x_un = net_generator.unnormalize(x) - mel_gen = feature_utils.decode(x_un) + mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2)) # ── Style loss ─────────────────────────────────────────────────────── loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)