perf: gradient checkpointing on vocoder forward to reduce activation memory

BigVGAN's 512x upsampling stack stores huge intermediate activations for
backward even in snake_alpha_only mode (only 5K trainable params, but
activation graph runs through the full network after each snake op).

Wrapping vocoder() in checkpoint(use_reentrant=False) recomputes activations
during backward instead of storing them — ~2x compute cost, large reduction
in peak VRAM. Should allow batch_size > 1 on 96 GB without OOM.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 13:45:24 +02:00
parent eece79ccae
commit 8166c56552
+8 -1
View File
@@ -636,7 +636,14 @@ def _do_train(vocoder, mel_converter, clips,
with torch.no_grad(): with torch.no_grad():
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel] target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
pred_wav = vocoder(target_mel) # [B, 1, T_wav] # Gradient checkpointing: recompute BigVGAN activations during
# backward instead of storing them. The 512x upsampling stack
# produces enormous intermediate tensors — checkpointing trades
# ~2x compute for a large reduction in activation memory, allowing
# batch_size > 1 without OOM.
pred_wav = torch.utils.checkpoint.checkpoint(
vocoder, target_mel, use_reentrant=False
) # [B, 1, T_wav]
T = min(pred_wav.shape[-1], target_wav.shape[-1]) T = min(pred_wav.shape[-1], target_wav.shape[-1])
pred_t = pred_wav[..., :T] pred_t = pred_wav[..., :T]