From 8166c5655207e03cce1332d6b0e8d43eb2166619 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 13:45:24 +0200 Subject: [PATCH] perf: gradient checkpointing on vocoder forward to reduce activation memory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BigVGAN's 512x upsampling stack stores huge intermediate activations for backward even in snake_alpha_only mode (only 5K trainable params, but activation graph runs through the full network after each snake op). Wrapping vocoder() in checkpoint(use_reentrant=False) recomputes activations during backward instead of storing them — ~2x compute cost, large reduction in peak VRAM. Should allow batch_size > 1 on 96 GB without OOM. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_bigvgan_trainer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 16ab5a1..5d6707e 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -636,7 +636,14 @@ def _do_train(vocoder, mel_converter, clips, with torch.no_grad(): target_mel = mel_converter(target_flat) # [B, n_mels, T_mel] - pred_wav = vocoder(target_mel) # [B, 1, T_wav] + # Gradient checkpointing: recompute BigVGAN activations during + # backward instead of storing them. The 512x upsampling stack + # produces enormous intermediate tensors — checkpointing trades + # ~2x compute for a large reduction in activation memory, allowing + # batch_size > 1 without OOM. + pred_wav = torch.utils.checkpoint.checkpoint( + vocoder, target_mel, use_reentrant=False + ) # [B, 1, T_wav] T = min(pred_wav.shape[-1], target_wav.shape[-1]) pred_t = pred_wav[..., :T]