fix: stop loading full feature_utils to GPU before training
feature_utils.to(device) was loading CLIP ViT-H, synchformer, T5, VAE, and vocoder (~90 GiB) to GPU for the entire training run. Now only mel_converter (tiny) is moved to GPU. Pre-generation manages its own device placement: temporarily moves CLIP and tod to GPU, then moves them back when done. This frees ~90 GiB for the backward pass. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -487,10 +487,16 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
||||
rng = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
# Move VAE+vocoder to device for decode
|
||||
# Move only the components we need to GPU for generation:
|
||||
# - tod (VAE+vocoder) for decode
|
||||
# - clip_model for encode_text_clip
|
||||
tod = feature_utils.tod
|
||||
tod_orig_dev = next(tod.parameters()).device
|
||||
tod.to(device)
|
||||
clip_model = feature_utils.clip_model
|
||||
if clip_model is not None:
|
||||
clip_orig_dev = next(clip_model.parameters()).device
|
||||
clip_model.to(device)
|
||||
|
||||
pairs = []
|
||||
try:
|
||||
@@ -565,6 +571,8 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
|
||||
finally:
|
||||
tod.to(tod_orig_dev)
|
||||
if clip_model is not None:
|
||||
clip_model.to(clip_orig_dev)
|
||||
del generator
|
||||
soft_empty_cache()
|
||||
|
||||
@@ -777,10 +785,10 @@ class SelvaBigvganTrainer:
|
||||
comfy.model_management.unload_all_models()
|
||||
soft_empty_cache()
|
||||
|
||||
if strategy == "offload_to_cpu":
|
||||
feature_utils.to(device)
|
||||
soft_empty_cache()
|
||||
|
||||
# Only move mel_converter to GPU — it's tiny and needed for training.
|
||||
# The rest of feature_utils (CLIP, synchformer, T5, VAE) stays on CPU;
|
||||
# _pregenerate_lora_mels handles its own device management for the parts
|
||||
# it needs temporarily.
|
||||
mel_converter.to(device)
|
||||
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
Reference in New Issue
Block a user