diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 3aa2d67..16ab5a1 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -35,6 +35,7 @@ import torch.nn as nn import torch.nn.functional as F import torchaudio import comfy.utils +import comfy.model_management import folder_paths from .utils import SELVA_CATEGORY, get_device, soft_empty_cache @@ -105,13 +106,13 @@ class _DiscriminatorR(nn.Module): from torch.nn.utils.parametrizations import weight_norm norm = weight_norm self.convs = nn.ModuleList([ - norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), - norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), - norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), - norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), - norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), + norm(nn.Conv2d(1, 128, (3, 9), padding=(1, 4))), + norm(nn.Conv2d(128, 128, (3, 9), stride=(1, 2), padding=(1, 4))), + norm(nn.Conv2d(128, 128, (3, 9), stride=(1, 2), padding=(1, 4))), + norm(nn.Conv2d(128, 128, (3, 9), stride=(1, 2), padding=(1, 4))), + norm(nn.Conv2d(128, 128, (3, 3), padding=(1, 1))), ]) - self.conv_post = norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) + self.conv_post = norm(nn.Conv2d(128, 1, (3, 3), padding=(1, 1))) def spectrogram(self, x): """x: [B, 1, T] → [B, 1, freq, time]""" @@ -408,6 +409,11 @@ class SelvaBigvganTrainer: f"segment={segment_seconds}s steps={steps} lr={lr} " f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True) + # Unload all other ComfyUI models (SelVA generator, etc.) to free VRAM + # before starting training. BigVGAN + discriminator need the headroom. + comfy.model_management.unload_all_models() + soft_empty_cache() + if strategy == "offload_to_cpu": feature_utils.to(device) soft_empty_cache()