From 1f02d73a3ed4fb55c95b74ad951265611971431c Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 17:40:00 +0200 Subject: [PATCH] =?UTF-8?q?fix:=20remove=20checkpoint=20wrapper=20on=20dec?= =?UTF-8?q?ode=20=E2=80=94=20direct=20call=20preserves=20grad=20chain?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _unnorm_decode was wrapped in checkpoint(use_reentrant=False) to avoid saving inference-mode weight tensors during backward. Since _strip_inference() now cleans all params/buffers before any forward pass, the checkpoint is no longer needed and was silently breaking the gradient chain from mel_gen back to x0. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_ditto_optimizer.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index 85916f2..1a93399 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -399,19 +399,10 @@ def _do_optimize(net_generator, feature_utils, mel_converter, x = x + dt * flow # ── Decode to mel (no vocoder — cheap) ────────────────────────────── - # Wrap unnormalize + decode in gradient checkpointing so PyTorch does - # not try to save model weights for backward. The VAE / generator - # weights are inference-flagged tensors (loaded in the main thread); - # saving them for backward would raise "Inference tensors cannot be - # saved for backward". checkpoint(use_reentrant=False) recomputes the - # forward during backward instead of storing activations. - def _unnorm_decode(x_in): - x_un = net_generator.unnormalize(x_in) - return feature_utils.decode(x_un) - - mel_gen = torch.utils.checkpoint.checkpoint( - _unnorm_decode, x, use_reentrant=False - ) + # Direct call — inference flags were stripped from all model weights + # at the top of _do_optimize, so no checkpoint wrapper is needed. + x_un = net_generator.unnormalize(x) + mel_gen = feature_utils.decode(x_un) # ── Style loss ─────────────────────────────────────────────────────── loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)