diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index 00fb174..4036c27 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -280,8 +280,8 @@ def _do_optimize(net_generator, feature_utils, mel_converter, torch.manual_seed(seed) - clip_f = features["clip_features"].to(device, dtype) - sync_f = features["sync_features"].to(device, dtype) + clip_f = features["clip_features"].to(device, dtype).clone() + sync_f = features["sync_features"].to(device, dtype).clone() net_generator.update_seq_lengths( latent_seq_len=seq_cfg.latent_seq_len, @@ -341,6 +341,12 @@ def _do_optimize(net_generator, feature_utils, mel_converter, # propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad. # The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is # treated as identity for gradient purposes (truncated BPTT). + # + # x may carry an inference tensor flag from Phase 1 (derived from + # conditions which were built outside inference_mode but may have + # propagated the flag). .clone() strips it so the STE addition does + # not try to save an inference tensor for backward. + x = x.clone() x = x + (x0 - x0.detach()) # ── Phase 2: run last n_grad_steps with gradient + checkpointing ── @@ -359,8 +365,19 @@ def _do_optimize(net_generator, feature_utils, mel_converter, x = x + dt * flow # ── Decode to mel (no vocoder — cheap) ────────────────────────────── - x_unnorm = net_generator.unnormalize(x) - mel_gen = feature_utils.decode(x_unnorm) # latent → mel [1, n_mels, T] + # Wrap unnormalize + decode in gradient checkpointing so PyTorch does + # not try to save model weights for backward. The VAE / generator + # weights are inference-flagged tensors (loaded in the main thread); + # saving them for backward would raise "Inference tensors cannot be + # saved for backward". checkpoint(use_reentrant=False) recomputes the + # forward during backward instead of storing activations. + def _unnorm_decode(x_in): + x_un = net_generator.unnormalize(x_in) + return feature_utils.decode(x_un) + + mel_gen = torch.utils.checkpoint.checkpoint( + _unnorm_decode, x, use_reentrant=False + ) # ── Style loss ─────────────────────────────────────────────────────── loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)