From cb9a1eef01760bd872a9488e5651adae2ebc368f Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 10 Apr 2026 00:44:38 +0200 Subject: [PATCH] fix: stop loading full feature_utils to GPU before training feature_utils.to(device) was loading CLIP ViT-H, synchformer, T5, VAE, and vocoder (~90 GiB) to GPU for the entire training run. Now only mel_converter (tiny) is moved to GPU. Pre-generation manages its own device placement: temporarily moves CLIP and tod to GPU, then moves them back when done. This frees ~90 GiB for the backward pass. Co-Authored-By: Claude Opus 4.6 --- nodes/selva_bigvgan_trainer.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 0ec9ee7..bc6074a 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -487,10 +487,16 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps) rng = torch.Generator(device=device).manual_seed(seed) - # Move VAE+vocoder to device for decode + # Move only the components we need to GPU for generation: + # - tod (VAE+vocoder) for decode + # - clip_model for encode_text_clip tod = feature_utils.tod tod_orig_dev = next(tod.parameters()).device tod.to(device) + clip_model = feature_utils.clip_model + if clip_model is not None: + clip_orig_dev = next(clip_model.parameters()).device + clip_model.to(device) pairs = [] try: @@ -565,6 +571,8 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, finally: tod.to(tod_orig_dev) + if clip_model is not None: + clip_model.to(clip_orig_dev) del generator soft_empty_cache() @@ -777,10 +785,10 @@ class SelvaBigvganTrainer: comfy.model_management.unload_all_models() soft_empty_cache() - if strategy == "offload_to_cpu": - feature_utils.to(device) - soft_empty_cache() - + # Only move mel_converter to GPU — it's tiny and needed for training. + # The rest of feature_utils (CLIP, synchformer, T5, VAE) stays on CPU; + # _pregenerate_lora_mels handles its own device management for the parts + # it needs temporarily. mel_converter.to(device) pbar = comfy.utils.ProgressBar(steps)