perf: gradient checkpointing on vocoder forward to reduce activation memory
BigVGAN's 512x upsampling stack stores very large intermediate activations for the backward pass, even in snake_alpha_only mode: only about 5K parameters are trainable, but the activation graph still runs through the full network after each snake op.

Wrapping the vocoder() call in checkpoint(use_reentrant=False) recomputes those activations during backward instead of storing them, at roughly 2x the compute cost and with a large reduction in peak VRAM. This should allow batch_size > 1 on 96 GB without OOM.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
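The compute-for-memory trade described in this message can be illustrated with a small pure-Python accounting sketch. It is not BigVGAN code: `layer`, `layer_grad`, `backward_full` and `backward_checkpointed` are hypothetical names, each "layer" is a scalar `sin(w * x)`, and "memory" is just a count of layer inputs held at once. It compares plain backprop against checkpointing in segments of `seg` layers.

```python
import math

def layer(x, w):
    """Toy layer y = sin(w * x); a stand-in for one upsampling block."""
    return math.sin(w * x)

def layer_grad(x, w, gy):
    """Backward of layer(): (dL/dx, dL/dw) given the upstream gradient dL/dy."""
    c = math.cos(w * x)
    return gy * w * c, gy * x * c

def backward_full(x0, ws):
    """Plain backprop: every layer input is kept from forward to backward."""
    acts = [x0]
    for w in ws:
        acts.append(layer(acts[-1], w))
    stored, forward_calls = len(ws), len(ws)
    gy, grads = 1.0, [0.0] * len(ws)
    for i in reversed(range(len(ws))):
        gy, grads[i] = layer_grad(acts[i], ws[i], gy)
    return grads, stored, forward_calls

def backward_checkpointed(x0, ws, seg):
    """Keep only each segment's input; recompute the rest during backward."""
    n, ckpts, x, forward_calls = len(ws), {}, x0, 0
    for i, w in enumerate(ws):
        if i % seg == 0:
            ckpts[i] = x              # the only values saved in the forward pass
        x = layer(x, w)
        forward_calls += 1
    peak = len(ckpts)
    gy, grads = 1.0, [0.0] * n
    for start in reversed(range(0, n, seg)):
        end = min(start + seg, n)
        acts = [ckpts[start]]         # rebuild this segment's layer inputs
        for i in range(start, end - 1):
            acts.append(layer(acts[-1], ws[i]))
            forward_calls += 1
        # acts[0] is the checkpoint itself, so only the rest is extra storage
        peak = max(peak, len(ckpts) + len(acts) - 1)
        for i in reversed(range(start, end)):
            gy, grads[i] = layer_grad(acts[i - start], ws[i], gy)
    return grads, peak, forward_calls

ws = [0.5 + 0.1 * i for i in range(16)]
print(backward_full(0.3, ws)[1:])             # (live layer inputs, layer evaluations)
print(backward_checkpointed(0.3, ws, 4)[1:])
print(backward_checkpointed(0.3, ws, 16)[1:])
```

In this toy accounting, 16 layers in segments of 4 need 7 live values and 28 layer evaluations, against 16 and 16 for plain backprop, with identical gradients. A single segment spanning all 16 layers still holds 16 values while it recomputes. How much a single checkpoint around the real vocoder saves therefore depends on what else is resident between forward and backward, and is worth measuring.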
@@ -636,7 +636,14 @@ def _do_train(vocoder, mel_converter, clips,
     with torch.no_grad():
         target_mel = mel_converter(target_flat)  # [B, n_mels, T_mel]

-    pred_wav = vocoder(target_mel)  # [B, 1, T_wav]
+    # Gradient checkpointing: recompute BigVGAN activations during
+    # backward instead of storing them. The 512x upsampling stack
+    # produces enormous intermediate tensors — checkpointing trades
+    # ~2x compute for a large reduction in activation memory, allowing
+    # batch_size > 1 without OOM.
+    pred_wav = torch.utils.checkpoint.checkpoint(
+        vocoder, target_mel, use_reentrant=False
+    )  # [B, 1, T_wav]

     T = min(pred_wav.shape[-1], target_wav.shape[-1])
     pred_t = pred_wav[..., :T]
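The call pattern in this hunk can be checked in isolation. The sketch below assumes a hypothetical `toy_vocoder` (two transposed-conv stages, not BigVGAN) and a `mel` tensor created under `torch.no_grad()`, like `target_mel`, so it does not require grad. It compares parameter gradients with and without `checkpoint(..., use_reentrant=False)`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-in for the vocoder: two 2x upsampling stages.
toy_vocoder = nn.Sequential(
    nn.ConvTranspose1d(8, 4, kernel_size=4, stride=2, padding=1),
    nn.Tanh(),
    nn.ConvTranspose1d(4, 1, kernel_size=4, stride=2, padding=1),
)

# Like target_mel in the diff, the input is built without autograd tracking.
with torch.no_grad():
    mel = torch.randn(2, 8, 16)  # [B, n_mels, T_mel]

def param_grads(use_checkpoint):
    """Run forward and backward once and return a copy of each parameter gradient."""
    toy_vocoder.zero_grad(set_to_none=True)
    if use_checkpoint:
        # Activations inside toy_vocoder are recomputed during backward.
        wav = checkpoint(toy_vocoder, mel, use_reentrant=False)
    else:
        wav = toy_vocoder(mel)
    wav.pow(2).mean().backward()
    return [p.grad.clone() for p in toy_vocoder.parameters()]

plain = param_grads(use_checkpoint=False)
ckpt = param_grads(use_checkpoint=True)
grads_match = all(torch.allclose(a, b) for a, b in zip(plain, ckpt))
print("gradients match:", grads_match)
```

The non-reentrant variant is used here because the input does not require grad. The reentrant variant needs at least one input with `requires_grad=True` for gradients to reach the module's parameters.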