fix: strip inference tensor flags in DITTO optimizer

Two crash paths under "RuntimeError: Inference tensors cannot be
saved for backward":

1. clip_f / sync_f loaded from main-thread inference_mode carry the
   inference flag. Clone them on entry to the worker thread so the
   conditions built from them are clean non-inference tensors.
   Also clone x after Phase 1 before the STE reconnection — Phase 1
   runs under no_grad and produces outputs that may still carry the
   flag through the conditions path.

2. net_generator.unnormalize + feature_utils.decode called outside
   any checkpoint wrapper with requires_grad=True input. Backward
   tried to save inference-flagged model weights. Wrapped both calls
   in checkpoint(use_reentrant=False) so they recompute on backward
   instead of storing activations.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 12:18:20 +02:00
parent 211494a91c
commit 357b875e5e
+21 -4
View File
@@ -280,8 +280,8 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
torch.manual_seed(seed)
clip_f = features["clip_features"].to(device, dtype)
sync_f = features["sync_features"].to(device, dtype)
clip_f = features["clip_features"].to(device, dtype).clone()
sync_f = features["sync_features"].to(device, dtype).clone()
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
@@ -341,6 +341,12 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
# propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
# The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
# treated as identity for gradient purposes (truncated BPTT).
#
# x may carry an inference tensor flag from Phase 1 (derived from
# conditions which were built outside inference_mode but may have
# propagated the flag). .clone() strips it so the STE addition does
# not try to save an inference tensor for backward.
x = x.clone()
x = x + (x0 - x0.detach())
# ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
@@ -359,8 +365,19 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
x = x + dt * flow
# ── Decode to mel (no vocoder — cheap) ──────────────────────────────
x_unnorm = net_generator.unnormalize(x)
mel_gen = feature_utils.decode(x_unnorm) # latent → mel [1, n_mels, T]
# Wrap unnormalize + decode in gradient checkpointing so PyTorch does
# not try to save model weights for backward. The VAE / generator
# weights are inference-flagged tensors (loaded in the main thread);
# saving them for backward would raise "Inference tensors cannot be
# saved for backward". checkpoint(use_reentrant=False) recomputes the
# forward during backward instead of storing activations.
def _unnorm_decode(x_in):
x_un = net_generator.unnormalize(x_in)
return feature_utils.decode(x_un)
mel_gen = torch.utils.checkpoint.checkpoint(
_unnorm_decode, x, use_reentrant=False
)
# ── Style loss ───────────────────────────────────────────────────────
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)