debug: add driver-level VRAM reporting + offload video_enc

torch.cuda.memory_allocated only tracks memory held by the PyTorch caching
allocator. Added torch.cuda.mem_get_info to report the actual device memory
usage as seen by the CUDA driver.
Also offload video_enc (TextSynch), which was missed in the original
offload: it stays on the GPU when strategy != offload_to_cpu.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 01:48:04 +02:00
parent 9af4bbdd91
commit 4297715a08
+13 -4
View File
@@ -776,8 +776,10 @@ class SelvaBigvganTrainer:
if device.type == "cuda":
alloc = torch.cuda.memory_allocated(device) / (1024**3)
resrv = torch.cuda.memory_reserved(device) / (1024**3)
print(f"[BigVGAN VRAM] {label}: {alloc:.2f} GiB allocated, "
f"{resrv:.2f} GiB reserved", flush=True)
free_cuda, total_cuda = torch.cuda.mem_get_info(device)
used_driver = (total_cuda - free_cuda) / (1024**3)
print(f"[BigVGAN VRAM] {label}: alloc={alloc:.2f} reserved={resrv:.2f} "
f"driver_used={used_driver:.2f} GiB", flush=True)
_vram_log("before unload")
comfy.model_management.unload_all_models()
@@ -791,6 +793,9 @@ class SelvaBigvganTrainer:
if "generator" in model:
model["generator"].to("cpu")
_vram_log("after generator.to(cpu)")
if "video_enc" in model:
model["video_enc"].to("cpu")
_vram_log("after video_enc.to(cpu)")
soft_empty_cache()
_vram_log("after soft_empty_cache")
@@ -1093,8 +1098,12 @@ def _do_train(vocoder, mel_converter, clips,
if device.type == "cuda":
alloc = torch.cuda.memory_allocated(device) / (1024**3)
resrv = torch.cuda.memory_reserved(device) / (1024**3)
print(f"[BigVGAN VRAM] before training: {alloc:.2f} GiB allocated, "
f"{resrv:.2f} GiB reserved", flush=True)
free_cuda, total_cuda = torch.cuda.mem_get_info(device)
used_driver = (total_cuda - free_cuda) / (1024**3)
print(f"[BigVGAN VRAM] before training: "
f"pytorch_alloc={alloc:.2f} GiB, pytorch_reserved={resrv:.2f} GiB, "
f"driver_used={used_driver:.2f} GiB, driver_total={total_cuda/(1024**3):.2f} GiB",
flush=True)
optimizer = torch.optim.AdamW(trainable_params, lr=lr, betas=(0.8, 0.99))
vocoder.train()