fix: strip inference tensor flags in DITTO optimizer

Two crash paths under "RuntimeError: Inference tensors cannot be saved
for backward":

1. clip_f / sync_f, loaded under main-thread inference_mode, carry the
   inference flag. Clone them on entry to the worker thread so the
   conditions built from them are clean non-inference tensors. Also
   clone x after Phase 1, before the STE reconnection: Phase 1 runs
   under no_grad and produces outputs that may still carry the flag
   through the conditions path.

2. net_generator.unnormalize + feature_utils.decode were called outside
   any checkpoint wrapper with a requires_grad=True input, so backward
   tried to save inference-flagged model weights. Wrap both calls in
   checkpoint(use_reentrant=False) so they recompute on backward
   instead of storing activations.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
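
Reviewer note: the first crash path can be reproduced in isolation. The sketch below is not the project's code. The tensor shapes and the names `w`, `clean`, and `still_flagged` are illustrative assumptions, and `clip_f` here is a random stand-in for the real features. It shows that a tensor created under `torch.inference_mode()` keeps its flag through a no-op `.to(...)`, and that `.clone()` outside inference mode returns a normal tensor that autograd can save.

```python
import torch

# Stand-in for features produced by the main thread under inference_mode.
with torch.inference_mode():
    clip_f = torch.randn(2, 4)

# A .to() call that changes neither device nor dtype returns the same
# tensor, so the inference flag survives it.
still_flagged = clip_f.to(clip_f.device, clip_f.dtype).is_inference()

# Hypothetical trainable tensor standing in for the optimized latent.
w = torch.randn(2, 4, requires_grad=True)

# Multiplication has to save clip_f to compute the gradient w.r.t. w,
# which autograd refuses to do for an inference tensor.
try:
    (w * clip_f).sum().backward()
    raised = False
except RuntimeError as err:
    raised = "Inference tensors cannot be saved for backward" in str(err)

# Cloning outside inference_mode yields a normal tensor.
clean = clip_f.clone()
(w * clean).sum().backward()
```

After the clone, `w.grad` is populated as usual; the values of `clean` are identical to those of `clip_f`.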
@@ -280,8 +280,8 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
 
     torch.manual_seed(seed)
 
-    clip_f = features["clip_features"].to(device, dtype)
-    sync_f = features["sync_features"].to(device, dtype)
+    clip_f = features["clip_features"].to(device, dtype).clone()
+    sync_f = features["sync_features"].to(device, dtype).clone()
 
     net_generator.update_seq_lengths(
         latent_seq_len=seq_cfg.latent_seq_len,
@@ -341,6 +341,12 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
     # propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
     # The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
     # treated as identity for gradient purposes (truncated BPTT).
+    #
+    # x may carry an inference tensor flag from Phase 1 (derived from
+    # conditions which were built outside inference_mode but may have
+    # propagated the flag). .clone() strips it so the STE addition does
+    # not try to save an inference tensor for backward.
+    x = x.clone()
     x = x + (x0 - x0.detach())
 
     # ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
@@ -359,8 +365,19 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
||||
x = x + dt * flow
|
||||
|
||||
# ── Decode to mel (no vocoder — cheap) ──────────────────────────────
|
||||
x_unnorm = net_generator.unnormalize(x)
|
||||
mel_gen = feature_utils.decode(x_unnorm) # latent → mel [1, n_mels, T]
|
||||
# Wrap unnormalize + decode in gradient checkpointing so PyTorch does
|
||||
# not try to save model weights for backward. The VAE / generator
|
||||
# weights are inference-flagged tensors (loaded in the main thread);
|
||||
# saving them for backward would raise "Inference tensors cannot be
|
||||
# saved for backward". checkpoint(use_reentrant=False) recomputes the
|
||||
# forward during backward instead of storing activations.
|
||||
def _unnorm_decode(x_in):
|
||||
x_un = net_generator.unnormalize(x_in)
|
||||
return feature_utils.decode(x_un)
|
||||
|
||||
mel_gen = torch.utils.checkpoint.checkpoint(
|
||||
_unnorm_decode, x, use_reentrant=False
|
||||
)
|
||||
|
||||
# ── Style loss ───────────────────────────────────────────────────────
|
||||
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
|
||||
|
||||
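
Reviewer note: the gradient path touched by the last two hunks can be checked on a toy problem. This is a minimal sketch under stated assumptions, not the project's pipeline: `decoder` is a hypothetical `torch.nn.Linear` standing in for `net_generator.unnormalize` + `feature_utils.decode`, the Phase 1 update rule is made up, and the loss is a plain sum. It only illustrates that the STE line leaves the forward value unchanged while routing the gradient to `x0` as if the no-grad prefix were the identity, and that a non-reentrant checkpoint around the decode step still delivers that gradient. It does not exercise inference-flagged weights.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
x0 = torch.randn(1, 4, requires_grad=True)

# Phase 1 stand-in: a made-up update that runs without gradient tracking.
with torch.no_grad():
    x = 2.0 * x0 + 1.0

# STE reconnection: (x0 - x0.detach()) is zero in value but carries an
# identity gradient back to x0.
x_ste = x.clone() + (x0 - x0.detach())

# Hypothetical decoder standing in for unnormalize + decode.
decoder = torch.nn.Linear(4, 3)

def _decode(x_in):
    return decoder(x_in)

# Non-reentrant checkpoint: the forward of _decode is recomputed during
# backward instead of keeping its intermediate activations.
mel = checkpoint(_decode, x_ste, use_reentrant=False)
mel.sum().backward()

# With an identity prefix, d(sum(W x + b))/dx0 is the column sum of W.
# The true chain rule through the made-up Phase 1 would give twice that.
expected = decoder.weight.detach().sum(dim=0, keepdim=True)
```

Comparing `x0.grad` against `expected` confirms the truncated-BPTT approximation described in the hunk comments: the factor of 2 from the no-grad prefix is deliberately dropped.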