diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 39954c6..01853bc 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -772,37 +772,21 @@ class SelvaBigvganTrainer: # Unload all other ComfyUI models (SelVA generator, etc.) to free VRAM # before starting training. BigVGAN + discriminator need the headroom. - def _vram_log(label): - if device.type == "cuda": - alloc = torch.cuda.memory_allocated(device) / (1024**3) - resrv = torch.cuda.memory_reserved(device) / (1024**3) - free_cuda, total_cuda = torch.cuda.mem_get_info(device) - used_driver = (total_cuda - free_cuda) / (1024**3) - print(f"[BigVGAN VRAM] {label}: alloc={alloc:.2f} reserved={resrv:.2f} " - f"driver_used={used_driver:.2f} GiB", flush=True) - - _vram_log("before unload") comfy.model_management.unload_all_models() - _vram_log("after unload_all_models") # Move EVERYTHING to CPU first, then bring back only what we need. # ComfyUI may have loaded the full model to GPU; unload_all_models # doesn't always free model dicts passed between nodes. feature_utils.to("cpu") - _vram_log("after feature_utils.to(cpu)") if "generator" in model: model["generator"].to("cpu") - _vram_log("after generator.to(cpu)") if "video_enc" in model: model["video_enc"].to("cpu") - _vram_log("after video_enc.to(cpu)") soft_empty_cache() - _vram_log("after soft_empty_cache") # Only move mel_converter to GPU — it's tiny and needed for training. # _pregenerate_lora_mels handles its own device management for CLIP/tod. mel_converter.to(device) - _vram_log("after mel_converter.to(device)") # Pre-compute text CLIP embeddings in the main thread. # CLIP weights are inference tensors from ComfyUI loading — they only @@ -1094,17 +1078,6 @@ def _do_train(vocoder, mel_converter, clips, f"falling back to mel+STFT losses", flush=True) mpd = mrd = None - # VRAM snapshot before training loop - if device.type == "cuda": - alloc = torch.cuda.memory_allocated(device) / (1024**3) - resrv = torch.cuda.memory_reserved(device) / (1024**3) - free_cuda, total_cuda = torch.cuda.mem_get_info(device) - used_driver = (total_cuda - free_cuda) / (1024**3) - print(f"[BigVGAN VRAM] before training: " - f"pytorch_alloc={alloc:.2f} GiB, pytorch_reserved={resrv:.2f} GiB, " - f"driver_used={used_driver:.2f} GiB, driver_total={total_cuda/(1024**3):.2f} GiB", - flush=True) - optimizer = torch.optim.AdamW(trainable_params, lr=lr, betas=(0.8, 0.99)) vocoder.train() @@ -1126,11 +1099,6 @@ def _do_train(vocoder, mel_converter, clips, print(f"[BigVGAN] LoRA mel cropping: {_mel_segment} mel frames " f"per {segment_samples} audio samples", flush=True) - def _vram(label): - if device.type == "cuda" and step < 1: - a = torch.cuda.memory_allocated(device) / (1024**3) - print(f" [VRAM step0] {label}: {a:.2f} GiB", flush=True) - try: for step in range(steps): if lora_mel_pairs: @@ -1173,7 +1141,7 @@ def _do_train(vocoder, mel_converter, clips, # Clean target mel for mel loss (always from clean audio) with torch.no_grad(): target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel] - _vram("after target_mel") + # Gradient checkpointing: recompute BigVGAN activations during # backward instead of storing them. The 512x upsampling stack @@ -1183,14 +1151,14 @@ def _do_train(vocoder, mel_converter, clips, pred_wav = torch.utils.checkpoint.checkpoint( vocoder, input_mel.to(dtype), use_reentrant=False ) # [B, 1, T_wav] - _vram("after vocoder forward") + T = min(pred_wav.shape[-1], target_wav.shape[-1]) pred_t = pred_wav[..., :T] target_t = target_wav[..., :T] # ── Compute loss ───────────────────────────────────────────────── - _vram("before loss") + if mpd is not None and mrd is not None: # Perceptual feature matching via frozen discriminators with torch.no_grad(): @@ -1236,10 +1204,10 @@ def _do_train(vocoder, mel_converter, clips, l2sp_loss = l2sp_loss * lambda_l2sp loss = primary_loss + l2sp_loss - _vram("after loss computation") + optimizer.zero_grad() loss.backward() - _vram("after backward") + torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) optimizer.step()