chore: remove debug VRAM logging
Training confirmed working — VRAM usage is normal backward-pass activation memory, not a leak. Removed the debug _vram_log and _vram helpers, every call to them, and the inline VRAM snapshot printed before the training loop. Kept the video_enc offload and torch.cuda.empty_cache fixes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -772,37 +772,21 @@ class SelvaBigvganTrainer:
|
|||||||
|
|
||||||
# Unload all other ComfyUI models (SelVA generator, etc.) to free VRAM
|
# Unload all other ComfyUI models (SelVA generator, etc.) to free VRAM
|
||||||
# before starting training. BigVGAN + discriminator need the headroom.
|
# before starting training. BigVGAN + discriminator need the headroom.
|
||||||
def _vram_log(label):
|
|
||||||
if device.type == "cuda":
|
|
||||||
alloc = torch.cuda.memory_allocated(device) / (1024**3)
|
|
||||||
resrv = torch.cuda.memory_reserved(device) / (1024**3)
|
|
||||||
free_cuda, total_cuda = torch.cuda.mem_get_info(device)
|
|
||||||
used_driver = (total_cuda - free_cuda) / (1024**3)
|
|
||||||
print(f"[BigVGAN VRAM] {label}: alloc={alloc:.2f} reserved={resrv:.2f} "
|
|
||||||
f"driver_used={used_driver:.2f} GiB", flush=True)
|
|
||||||
|
|
||||||
_vram_log("before unload")
|
|
||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
_vram_log("after unload_all_models")
|
|
||||||
|
|
||||||
# Move EVERYTHING to CPU first, then bring back only what we need.
|
# Move EVERYTHING to CPU first, then bring back only what we need.
|
||||||
# ComfyUI may have loaded the full model to GPU; unload_all_models
|
# ComfyUI may have loaded the full model to GPU; unload_all_models
|
||||||
# doesn't always free model dicts passed between nodes.
|
# doesn't always free model dicts passed between nodes.
|
||||||
feature_utils.to("cpu")
|
feature_utils.to("cpu")
|
||||||
_vram_log("after feature_utils.to(cpu)")
|
|
||||||
if "generator" in model:
|
if "generator" in model:
|
||||||
model["generator"].to("cpu")
|
model["generator"].to("cpu")
|
||||||
_vram_log("after generator.to(cpu)")
|
|
||||||
if "video_enc" in model:
|
if "video_enc" in model:
|
||||||
model["video_enc"].to("cpu")
|
model["video_enc"].to("cpu")
|
||||||
_vram_log("after video_enc.to(cpu)")
|
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
_vram_log("after soft_empty_cache")
|
|
||||||
|
|
||||||
# Only move mel_converter to GPU — it's tiny and needed for training.
|
# Only move mel_converter to GPU — it's tiny and needed for training.
|
||||||
# _pregenerate_lora_mels handles its own device management for CLIP/tod.
|
# _pregenerate_lora_mels handles its own device management for CLIP/tod.
|
||||||
mel_converter.to(device)
|
mel_converter.to(device)
|
||||||
_vram_log("after mel_converter.to(device)")
|
|
||||||
|
|
||||||
# Pre-compute text CLIP embeddings in the main thread.
|
# Pre-compute text CLIP embeddings in the main thread.
|
||||||
# CLIP weights are inference tensors from ComfyUI loading — they only
|
# CLIP weights are inference tensors from ComfyUI loading — they only
|
||||||
@@ -1094,17 +1078,6 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
f"falling back to mel+STFT losses", flush=True)
|
f"falling back to mel+STFT losses", flush=True)
|
||||||
mpd = mrd = None
|
mpd = mrd = None
|
||||||
|
|
||||||
# VRAM snapshot before training loop
|
|
||||||
if device.type == "cuda":
|
|
||||||
alloc = torch.cuda.memory_allocated(device) / (1024**3)
|
|
||||||
resrv = torch.cuda.memory_reserved(device) / (1024**3)
|
|
||||||
free_cuda, total_cuda = torch.cuda.mem_get_info(device)
|
|
||||||
used_driver = (total_cuda - free_cuda) / (1024**3)
|
|
||||||
print(f"[BigVGAN VRAM] before training: "
|
|
||||||
f"pytorch_alloc={alloc:.2f} GiB, pytorch_reserved={resrv:.2f} GiB, "
|
|
||||||
f"driver_used={used_driver:.2f} GiB, driver_total={total_cuda/(1024**3):.2f} GiB",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(trainable_params, lr=lr, betas=(0.8, 0.99))
|
optimizer = torch.optim.AdamW(trainable_params, lr=lr, betas=(0.8, 0.99))
|
||||||
vocoder.train()
|
vocoder.train()
|
||||||
|
|
||||||
@@ -1126,11 +1099,6 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames "
|
print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames "
|
||||||
f"per {segment_samples} audio samples", flush=True)
|
f"per {segment_samples} audio samples", flush=True)
|
||||||
|
|
||||||
def _vram(label):
|
|
||||||
if device.type == "cuda" and step < 1:
|
|
||||||
a = torch.cuda.memory_allocated(device) / (1024**3)
|
|
||||||
print(f" [VRAM step0] {label}: {a:.2f} GiB", flush=True)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for step in range(steps):
|
for step in range(steps):
|
||||||
if lora_mel_pairs:
|
if lora_mel_pairs:
|
||||||
@@ -1173,7 +1141,7 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
# Clean target mel for mel loss (always from clean audio)
|
# Clean target mel for mel loss (always from clean audio)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel]
|
target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel]
|
||||||
_vram("after target_mel")
|
|
||||||
|
|
||||||
# Gradient checkpointing: recompute BigVGAN activations during
|
# Gradient checkpointing: recompute BigVGAN activations during
|
||||||
# backward instead of storing them. The 512x upsampling stack
|
# backward instead of storing them. The 512x upsampling stack
|
||||||
@@ -1183,14 +1151,14 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
pred_wav = torch.utils.checkpoint.checkpoint(
|
pred_wav = torch.utils.checkpoint.checkpoint(
|
||||||
vocoder, input_mel.to(dtype), use_reentrant=False
|
vocoder, input_mel.to(dtype), use_reentrant=False
|
||||||
) # [B, 1, T_wav]
|
) # [B, 1, T_wav]
|
||||||
_vram("after vocoder forward")
|
|
||||||
|
|
||||||
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
||||||
pred_t = pred_wav[..., :T]
|
pred_t = pred_wav[..., :T]
|
||||||
target_t = target_wav[..., :T]
|
target_t = target_wav[..., :T]
|
||||||
|
|
||||||
# ── Compute loss ─────────────────────────────────────────────────
|
# ── Compute loss ─────────────────────────────────────────────────
|
||||||
_vram("before loss")
|
|
||||||
if mpd is not None and mrd is not None:
|
if mpd is not None and mrd is not None:
|
||||||
# Perceptual feature matching via frozen discriminators
|
# Perceptual feature matching via frozen discriminators
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -1236,10 +1204,10 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
l2sp_loss = l2sp_loss * lambda_l2sp
|
l2sp_loss = l2sp_loss * lambda_l2sp
|
||||||
|
|
||||||
loss = primary_loss + l2sp_loss
|
loss = primary_loss + l2sp_loss
|
||||||
_vram("after loss computation")
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
_vram("after backward")
|
|
||||||
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
|
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user