fix: correct MRD channel width to 128 and unload models before training
Two bugs: 1. _DiscriminatorR used channels=32 but the BigVGAN pretrained discriminator checkpoint has channels=128. All convs in _DiscriminatorR now use 128, matching the checkpoint architecture so state_dict loads without error. 2. BigVGAN trainer OOM: SelVA generator and other ComfyUI models remain in VRAM during training (~90 GiB used). Add unload_all_models() + cache flush before the training loop to reclaim VRAM headroom. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -35,6 +35,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||||
@@ -105,13 +106,13 @@ class _DiscriminatorR(nn.Module):
|
|||||||
from torch.nn.utils.parametrizations import weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
norm = weight_norm
|
norm = weight_norm
|
||||||
self.convs = nn.ModuleList([
|
self.convs = nn.ModuleList([
|
||||||
norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
|
norm(nn.Conv2d(1, 128, (3, 9), padding=(1, 4))),
|
||||||
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
norm(nn.Conv2d(128, 128, (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
norm(nn.Conv2d(128, 128, (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
norm(nn.Conv2d(128, 128, (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
|
norm(nn.Conv2d(128, 128, (3, 3), padding=(1, 1))),
|
||||||
])
|
])
|
||||||
self.conv_post = norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
|
self.conv_post = norm(nn.Conv2d(128, 1, (3, 3), padding=(1, 1)))
|
||||||
|
|
||||||
def spectrogram(self, x):
|
def spectrogram(self, x):
|
||||||
"""x: [B, 1, T] → [B, 1, freq, time]"""
|
"""x: [B, 1, T] → [B, 1, freq, time]"""
|
||||||
@@ -408,6 +409,11 @@ class SelvaBigvganTrainer:
|
|||||||
f"segment={segment_seconds}s steps={steps} lr={lr} "
|
f"segment={segment_seconds}s steps={steps} lr={lr} "
|
||||||
f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True)
|
f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True)
|
||||||
|
|
||||||
|
# Unload all other ComfyUI models (SelVA generator, etc.) to free VRAM
|
||||||
|
# before starting training. BigVGAN + discriminator need the headroom.
|
||||||
|
comfy.model_management.unload_all_models()
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
if strategy == "offload_to_cpu":
|
||||||
feature_utils.to(device)
|
feature_utils.to(device)
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user