diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index c26455e..699a701 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -1111,6 +1111,11 @@ def _do_train(vocoder, mel_converter, clips, print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames " f"per {segment_samples} audio samples", flush=True) + def _vram(label): + if device.type == "cuda" and step < 1: + a = torch.cuda.memory_allocated(device) / (1024**3) + print(f" [VRAM step0] {label}: {a:.2f} GiB", flush=True) + try: for step in range(steps): if lora_mel_pairs: @@ -1153,6 +1158,7 @@ def _do_train(vocoder, mel_converter, clips, # Clean target mel for mel loss (always from clean audio) with torch.no_grad(): target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel] + _vram("after target_mel") # Gradient checkpointing: recompute BigVGAN activations during # backward instead of storing them. The 512x upsampling stack @@ -1162,12 +1168,14 @@ def _do_train(vocoder, mel_converter, clips, pred_wav = torch.utils.checkpoint.checkpoint( vocoder, input_mel.to(dtype), use_reentrant=False ) # [B, 1, T_wav] + _vram("after vocoder forward") T = min(pred_wav.shape[-1], target_wav.shape[-1]) pred_t = pred_wav[..., :T] target_t = target_wav[..., :T] # ── Compute loss ───────────────────────────────────────────────── + _vram("before loss") if mpd is not None and mrd is not None: # Perceptual feature matching via frozen discriminators with torch.no_grad(): @@ -1213,8 +1221,10 @@ def _do_train(vocoder, mel_converter, clips, l2sp_loss = l2sp_loss * lambda_l2sp loss = primary_loss + l2sp_loss + _vram("after loss computation") optimizer.zero_grad() loss.backward() + _vram("after backward") torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) optimizer.step()