fix: correct MRD channel width to 128 and unload models before training

Two bugs:

1. _DiscriminatorR used channels=32 but the BigVGAN pretrained discriminator
   checkpoint has channels=128. All convs in _DiscriminatorR now use 128,
   matching the checkpoint architecture so state_dict loads without error.

2. BigVGAN trainer OOM: SelVA generator and other ComfyUI models remain in
   VRAM during training (~90 GiB used). Add unload_all_models() + cache
   flush before the training loop to reclaim VRAM headroom.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 13:40:01 +02:00
parent 357b875e5e
commit eece79ccae
+12 -6
View File
@@ -35,6 +35,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
import comfy.utils import comfy.utils
import comfy.model_management
import folder_paths import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
@@ -105,13 +106,13 @@ class _DiscriminatorR(nn.Module):
from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrizations import weight_norm
norm = weight_norm norm = weight_norm
self.convs = nn.ModuleList([ self.convs = nn.ModuleList([
norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), norm(nn.Conv2d(1, 128, (3, 9), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), norm(nn.Conv2d(128, 128, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), norm(nn.Conv2d(128, 128, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), norm(nn.Conv2d(128, 128, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), norm(nn.Conv2d(128, 128, (3, 3), padding=(1, 1))),
]) ])
self.conv_post = norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) self.conv_post = norm(nn.Conv2d(128, 1, (3, 3), padding=(1, 1)))
def spectrogram(self, x): def spectrogram(self, x):
"""x: [B, 1, T] → [B, 1, freq, time]""" """x: [B, 1, T] → [B, 1, freq, time]"""
@@ -408,6 +409,11 @@ class SelvaBigvganTrainer:
f"segment={segment_seconds}s steps={steps} lr={lr} " f"segment={segment_seconds}s steps={steps} lr={lr} "
f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True) f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True)
# Unload all other ComfyUI models (SelVA generator, etc.) to free VRAM
# before starting training. BigVGAN + discriminator need the headroom.
comfy.model_management.unload_all_models()
soft_empty_cache()
if strategy == "offload_to_cpu": if strategy == "offload_to_cpu":
feature_utils.to(device) feature_utils.to(device)
soft_empty_cache() soft_empty_cache()