From 89d6fccd280d83face9dda752af0580efd0017c1 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 10 Apr 2026 01:35:54 +0200 Subject: [PATCH] debug: add per-operation VRAM logging in first training step Logs VRAM at: after target_mel, after vocoder forward, before loss, after loss computation, and after backward. Only logs for step 0 to avoid spam. Will identify which operation causes the 94 GiB spike. Co-Authored-By: Claude Opus 4.6 --- nodes/selva_bigvgan_trainer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index c26455e..699a701 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -1111,6 +1111,11 @@ def _do_train(vocoder, mel_converter, clips, print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames " f"per {segment_samples} audio samples", flush=True) + def _vram(label): + if device.type == "cuda" and step < 1: + a = torch.cuda.memory_allocated(device) / (1024**3) + print(f" [VRAM step0] {label}: {a:.2f} GiB", flush=True) + try: for step in range(steps): if lora_mel_pairs: @@ -1153,6 +1158,7 @@ def _do_train(vocoder, mel_converter, clips, # Clean target mel for mel loss (always from clean audio) with torch.no_grad(): target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel] + _vram("after target_mel") # Gradient checkpointing: recompute BigVGAN activations during # backward instead of storing them. The 512x upsampling stack @@ -1162,12 +1168,14 @@ def _do_train(vocoder, mel_converter, clips, pred_wav = torch.utils.checkpoint.checkpoint( vocoder, input_mel.to(dtype), use_reentrant=False ) # [B, 1, T_wav] + _vram("after vocoder forward") T = min(pred_wav.shape[-1], target_wav.shape[-1]) pred_t = pred_wav[..., :T] target_t = target_wav[..., :T] # ── Compute loss ───────────────────────────────────────────────── + _vram("before loss") if mpd is not None and mrd is not None: # Perceptual feature matching via frozen discriminators with torch.no_grad(): @@ -1213,8 +1221,10 @@ def _do_train(vocoder, mel_converter, clips, l2sp_loss = l2sp_loss * lambda_l2sp loss = primary_loss + l2sp_loss + _vram("after loss computation") optimizer.zero_grad() loss.backward() + _vram("after backward") torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) optimizer.step()