fix: remove checkpoint wrapper on decode — direct call preserves grad chain
_unnorm_decode was wrapped in checkpoint(use_reentrant=False) to avoid saving inference-mode weight tensors during backward. Since _strip_inference() now cleans all params/buffers before any forward pass, the checkpoint is no longer needed and was silently breaking the gradient chain from mel_gen back to x0. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -399,19 +399,10 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
x = x + dt * flow
|
x = x + dt * flow
|
||||||
|
|
||||||
# ── Decode to mel (no vocoder — cheap) ──────────────────────────────
|
# ── Decode to mel (no vocoder — cheap) ──────────────────────────────
|
||||||
# Wrap unnormalize + decode in gradient checkpointing so PyTorch does
|
# Direct call — inference flags were stripped from all model weights
|
||||||
# not try to save model weights for backward. The VAE / generator
|
# at the top of _do_optimize, so no checkpoint wrapper is needed.
|
||||||
# weights are inference-flagged tensors (loaded in the main thread);
|
x_un = net_generator.unnormalize(x)
|
||||||
# saving them for backward would raise "Inference tensors cannot be
|
mel_gen = feature_utils.decode(x_un)
|
||||||
# saved for backward". checkpoint(use_reentrant=False) recomputes the
|
|
||||||
# forward during backward instead of storing activations.
|
|
||||||
def _unnorm_decode(x_in):
|
|
||||||
x_un = net_generator.unnormalize(x_in)
|
|
||||||
return feature_utils.decode(x_un)
|
|
||||||
|
|
||||||
mel_gen = torch.utils.checkpoint.checkpoint(
|
|
||||||
_unnorm_decode, x, use_reentrant=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Style loss ───────────────────────────────────────────────────────
|
# ── Style loss ───────────────────────────────────────────────────────
|
||||||
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
|
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
|
||||||
|
|||||||
Reference in New Issue
Block a user