fix: run BigVGAN training in a fresh thread to escape inference_mode

torch.inference_mode is thread-local. ComfyUI sets it on the node-execution
thread; inference_mode(False) alone is insufficient to escape it in some
environments (e.g. async wrappers, lora-manager hook). A new thread always
starts clean. Moved all training logic into _do_train() called via
threading.Thread so every tensor is a normal autograd tensor by default.
Simplified parameter cloning: clone().detach().requires_grad_(True).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 02:30:53 +02:00
parent df63b147e9
commit e870446b0f
+67 -75
View File
@@ -12,6 +12,7 @@ checkpoint so it can be loaded with SelVA BigVGAN Loader.
"""
import random
import threading
from pathlib import Path
import torch
@@ -121,17 +122,15 @@ class SelvaBigvganTrainer:
device = get_device()
mode = model["mode"]
dtype = model["dtype"] # bf16/fp16/fp32 — must match mel_converter buffers
dtype = model["dtype"]
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
strategy = model["strategy"]
if mode == "16k":
# BigVGANVocoder wrapped inside BigVGAN — bypass the @inference_mode on the wrapper
vocoder = feature_utils.tod.vocoder.vocoder
sample_rate = 16_000
elif mode == "44k":
# BigVGANv2 is the vocoder directly (no wrapper); no @inference_mode decorator
vocoder = feature_utils.tod.vocoder
sample_rate = 44_100
else:
@@ -168,7 +167,7 @@ class SelvaBigvganTrainer:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0) # [L]
if wav.shape[0] >= segment_samples:
clips.append(wav)
clips.append(wav.cpu())
else:
print(f" [BigVGAN] Skip {af.name}: shorter than {segment_seconds}s", flush=True)
except Exception as e:
@@ -188,21 +187,66 @@ class SelvaBigvganTrainer:
mel_converter.to(device)
pbar = comfy.utils.ProgressBar(steps)
# -----------------------------------------------------------------------
# Run the entire training in a fresh thread.
#
# ComfyUI executes nodes inside torch.inference_mode(). Even with an inner
# inference_mode(False) context, factory functions and operations may still
# produce inference tensors in some environments (e.g. when the outer
# context is set via an async wrapper or a third-party hook).
#
# torch.inference_mode is THREAD-LOCAL. A new thread always starts with
# inference_mode disabled, so all tensor operations in the worker thread
# produce normal, autograd-compatible tensors — no flags to fight.
# -----------------------------------------------------------------------
_result = [None]
_exc = [None]
def _worker():
try:
_result[0] = _do_train(
vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
steps, lr, batch_size, save_every, seed,
out_path, pbar,
)
except Exception as e:
_exc[0] = e
traceback.print_exc()
t = threading.Thread(target=_worker, daemon=True)
t.start()
t.join()
if _exc[0] is not None:
raise _exc[0]
return (_result[0],)
def _do_train(vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
steps, lr, batch_size, save_every, seed,
out_path, pbar):
"""Execute training. Called in a fresh thread — no inference_mode active."""
import torch.nn as nn_mod
torch.manual_seed(seed)
random.seed(seed)
# Fixed reference segment for eval samples — always clip 0, start 0
ref_clip = clips[0][:segment_samples].to(device, dtype) # [T]
ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel]
# Reference segment for eval samples — always clip 0, start 0
ref_wav = clips[0][:segment_samples].to(device, dtype) # [T]
ref_mel = mel_converter(ref_wav.unsqueeze(0)) # [1, n_mels, T_mel]
def _save_sample(label):
"""Vocode the reference mel and save as .wav."""
try:
# Vocoder may have been offloaded to CPU after training — match its device.
voc_device = next(vocoder.parameters()).device
mel = ref_mel.to(voc_device)
with torch.no_grad():
wav = vocoder(mel) # [1, 1, T] or [1, T]
wav = vocoder(mel)
if wav.dim() == 2:
wav = wav.unsqueeze(1)
wav = wav.float().cpu().clamp(-1, 1)
@@ -212,98 +256,46 @@ class SelvaBigvganTrainer:
except Exception as e:
print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True)
# Baseline: ground truth roundtrip before any fine-tuning
_save_sample("baseline")
pbar = comfy.utils.ProgressBar(steps)
try:
with torch.inference_mode(False):
with torch.enable_grad():
# Vocoder parameters are inference tensors (loaded inside ComfyUI's
# inference_mode). param.data = clone() only changes storage — the
# nn.Parameter object itself still carries the inference flag.
# Replace each parameter with a fresh nn.Parameter created here
# (inside inference_mode(False)) so the object itself is normal.
# param.data.clone() of an inference tensor still produces an
# inference tensor. Use torch.zeros + copy_ to create a genuinely
# fresh normal tensor, then wrap in nn.Parameter (created here,
# inside inference_mode(False), so it is a normal parameter).
import torch.nn as nn_mod
# Replace inference-tensor parameters with fresh trainable parameters.
# (Vocoder was loaded by ComfyUI before this thread existed; its parameters
# may carry the inference flag from that context. Here we're in a clean thread
# so all new tensors are normal.)
for module in vocoder.modules():
for pname, param in list(module._parameters.items()):
if param is not None:
fresh = torch.zeros(
param.shape, device=param.device, dtype=param.dtype
)
fresh.copy_(param.data)
module._parameters[pname] = nn_mod.Parameter(
fresh, requires_grad=True
)
fresh = param.data.clone().detach().requires_grad_(True)
module._parameters[pname] = nn_mod.Parameter(fresh)
# mel_converter and its submodules (e.g. Spectrogram.window) have
# inference-tensor buffers loaded in ComfyUI's outer inference_mode.
# Must iterate .modules() — ._buffers only covers direct buffers.
for sub in mel_converter.modules():
for bname, buf in list(sub._buffers.items()):
if buf is not None:
fresh = torch.zeros(
buf.shape, device=buf.device, dtype=buf.dtype
)
fresh.copy_(buf)
sub._buffers[bname] = fresh
optimizer = torch.optim.AdamW(
vocoder.parameters(), lr=lr, betas=(0.8, 0.99)
)
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
vocoder.train()
try:
for step in range(steps):
# Sample random batch
# Sample random batch — clips are CPU floats, move to device
batch = []
for _ in range(batch_size):
clip = random.choice(clips)
start = random.randint(0, clip.shape[0] - segment_samples)
batch.append(clip[start : start + segment_samples])
# clips were loaded in ComfyUI's outer inference_mode, so every
# element is an inference tensor. torch.stack().clone() is still
# an inference tensor (the flag propagates through all ops).
# Use zeros+copy_ to produce a genuine normal tensor.
_stacked = torch.stack(batch).to(device, dtype)
target_flat = torch.zeros(
_stacked.shape, device=device, dtype=dtype
)
target_flat.copy_(_stacked)
del _stacked
target_flat = torch.stack(batch).to(device, dtype) # [B, T]
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
# Compute target mel and guarantee it is not an inference tensor.
# Even with sanitized buffers a submodule we missed could still
# taint the output, so we always copy into a fresh tensor.
with torch.no_grad():
_mel = mel_converter(target_flat)
target_mel = torch.empty(
_mel.shape, device=device, dtype=dtype
)
target_mel.copy_(_mel)
del _mel
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
# Vocoder forward: mel → waveform
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
# Align lengths
T = min(pred_wav.shape[-1], target_wav.shape[-1])
pred_t = pred_wav[..., :T]
target_t = target_wav[..., :T]
# Mel reconstruction loss — no no_grad: grad must flow
# through pred_t → mel_converter → loss.
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, 80, T_mel']
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, n_mels, T_mel']
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
# Multi-resolution STFT loss
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
loss = mel_loss + stft_loss
@@ -337,4 +329,4 @@ class SelvaBigvganTrainer:
torch.save({"generator": vocoder.state_dict()}, str(out_path))
print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
_save_sample("final")
return (str(out_path),)
return str(out_path)