From 357b875e5e92083786a78aaeec948598dbe97c70 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 12:18:20 +0200 Subject: [PATCH] fix: strip inference tensor flags in DITTO optimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two crash paths under "RuntimeError: Inference tensors cannot be saved for backward": 1. clip_f / sync_f loaded from main-thread inference_mode carry the inference flag. Clone them on entry to the worker thread so the conditions built from them are clean non-inference tensors. Also clone x after Phase 1 before the STE reconnection — Phase 1 runs under no_grad and produces outputs that may still carry the flag through the conditions path. 2. net_generator.unnormalize + feature_utils.decode called outside any checkpoint wrapper with requires_grad=True input. Backward tried to save inference-flagged model weights. Wrapped both calls in checkpoint(use_reentrant=False) so they recompute on backward instead of storing activations. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_ditto_optimizer.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index 00fb174..4036c27 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -280,8 +280,8 @@ def _do_optimize(net_generator, feature_utils, mel_converter, torch.manual_seed(seed) - clip_f = features["clip_features"].to(device, dtype) - sync_f = features["sync_features"].to(device, dtype) + clip_f = features["clip_features"].to(device, dtype).clone() + sync_f = features["sync_features"].to(device, dtype).clone() net_generator.update_seq_lengths( latent_seq_len=seq_cfg.latent_seq_len, @@ -341,6 +341,12 @@ def _do_optimize(net_generator, feature_utils, mel_converter, # propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad. # The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is # treated as identity for gradient purposes (truncated BPTT). + # + # x may carry an inference tensor flag from Phase 1 (derived from + # conditions which were built outside inference_mode but may have + # propagated the flag). .clone() strips it so the STE addition does + # not try to save an inference tensor for backward. + x = x.clone() x = x + (x0 - x0.detach()) # ── Phase 2: run last n_grad_steps with gradient + checkpointing ── @@ -359,8 +365,19 @@ def _do_optimize(net_generator, feature_utils, mel_converter, x = x + dt * flow # ── Decode to mel (no vocoder — cheap) ────────────────────────────── - x_unnorm = net_generator.unnormalize(x) - mel_gen = feature_utils.decode(x_unnorm) # latent → mel [1, n_mels, T] + # Wrap unnormalize + decode in gradient checkpointing so PyTorch does + # not try to save model weights for backward. The VAE / generator + # weights are inference-flagged tensors (loaded in the main thread); + # saving them for backward would raise "Inference tensors cannot be + # saved for backward". checkpoint(use_reentrant=False) recomputes the + # forward during backward instead of storing activations. + def _unnorm_decode(x_in): + x_un = net_generator.unnormalize(x_in) + return feature_utils.decode(x_un) + + mel_gen = torch.utils.checkpoint.checkpoint( + _unnorm_decode, x, use_reentrant=False + ) # ── Style loss ─────────────────────────────────────────────────────── loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)