debug: add per-operation VRAM logging in first training step

Logs VRAM at five points: after target_mel, after the vocoder forward
pass, before the loss, after the loss computation, and after backward.
Logging is limited to step 0 to avoid spam. This should identify which
operation causes the 94 GiB spike.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 01:35:54 +02:00
parent bd84242fa1
commit 89d6fccd28
+10
View File
@@ -1111,6 +1111,11 @@ def _do_train(vocoder, mel_converter, clips,
print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames " print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames "
f"per {segment_samples} audio samples", flush=True) f"per {segment_samples} audio samples", flush=True)
def _vram(label):
    """Print CUDA memory usage tagged with *label* during the first step only.

    Reports three figures for ``device``, all in GiB:

    * allocated -- memory held by live tensors at the moment of the call;
    * peak      -- high-water mark of allocated memory recorded so far;
    * reserved  -- memory held by the caching allocator.

    The instantaneous "allocated" figure alone cannot expose a short-lived
    spike: temporaries created and freed inside one operation are already
    gone when the next log point is reached.  The peak figure only ever
    grows, so the first label at which it jumps brackets the operation
    responsible.

    ``device`` and ``step`` are free names resolved from the enclosing scope
    at call time, so this helper must only be called once the training loop
    has bound ``step``.  It is a no-op on non-CUDA devices and for any step
    other than the first.

    Args:
        label: Short description of the point in the step being measured.
    """
    # Short-circuit keeps `step` unevaluated on non-CUDA devices, as before.
    if device.type != "cuda" or step >= 1:
        return
    gib = 1024**3
    a = torch.cuda.memory_allocated(device) / gib
    peak = torch.cuda.max_memory_allocated(device) / gib
    reserved = torch.cuda.memory_reserved(device) / gib
    # Original prefix is kept so existing greps on "[VRAM step0]" still match.
    print(f" [VRAM step0] {label}: {a:.2f} GiB "
          f"(peak {peak:.2f} GiB, reserved {reserved:.2f} GiB)", flush=True)
try: try:
for step in range(steps): for step in range(steps):
if lora_mel_pairs: if lora_mel_pairs:
@@ -1153,6 +1158,7 @@ def _do_train(vocoder, mel_converter, clips,
# Clean target mel for mel loss (always from clean audio) # Clean target mel for mel loss (always from clean audio)
with torch.no_grad(): with torch.no_grad():
target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel] target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel]
_vram("after target_mel")
# Gradient checkpointing: recompute BigVGAN activations during # Gradient checkpointing: recompute BigVGAN activations during
# backward instead of storing them. The 512x upsampling stack # backward instead of storing them. The 512x upsampling stack
@@ -1162,12 +1168,14 @@ def _do_train(vocoder, mel_converter, clips,
pred_wav = torch.utils.checkpoint.checkpoint( pred_wav = torch.utils.checkpoint.checkpoint(
vocoder, input_mel.to(dtype), use_reentrant=False vocoder, input_mel.to(dtype), use_reentrant=False
) # [B, 1, T_wav] ) # [B, 1, T_wav]
_vram("after vocoder forward")
T = min(pred_wav.shape[-1], target_wav.shape[-1]) T = min(pred_wav.shape[-1], target_wav.shape[-1])
pred_t = pred_wav[..., :T] pred_t = pred_wav[..., :T]
target_t = target_wav[..., :T] target_t = target_wav[..., :T]
# ── Compute loss ───────────────────────────────────────────────── # ── Compute loss ─────────────────────────────────────────────────
_vram("before loss")
if mpd is not None and mrd is not None: if mpd is not None and mrd is not None:
# Perceptual feature matching via frozen discriminators # Perceptual feature matching via frozen discriminators
with torch.no_grad(): with torch.no_grad():
@@ -1213,8 +1221,10 @@ def _do_train(vocoder, mel_converter, clips,
l2sp_loss = l2sp_loss * lambda_l2sp l2sp_loss = l2sp_loss * lambda_l2sp
loss = primary_loss + l2sp_loss loss = primary_loss + l2sp_loss
_vram("after loss computation")
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
_vram("after backward")
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
optimizer.step() optimizer.step()