fix: run BigVGAN training in a fresh thread to escape inference_mode
torch.inference_mode is thread-local. ComfyUI sets it on the node-execution thread; inference_mode(False) alone is insufficient to escape it in some environments (e.g. async wrappers, lora-manager hook). A new thread always starts clean. Moved all training logic into _do_train() called via threading.Thread so every tensor is a normal autograd tensor by default. Simplified parameter cloning: clone().detach().requires_grad_(True). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -12,6 +12,7 @@ checkpoint so it can be loaded with SelVA BigVGAN Loader.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -121,17 +122,15 @@ class SelvaBigvganTrainer:
|
|||||||
|
|
||||||
device = get_device()
|
device = get_device()
|
||||||
mode = model["mode"]
|
mode = model["mode"]
|
||||||
dtype = model["dtype"] # bf16/fp16/fp32 — must match mel_converter buffers
|
dtype = model["dtype"]
|
||||||
feature_utils = model["feature_utils"]
|
feature_utils = model["feature_utils"]
|
||||||
mel_converter = feature_utils.mel_converter
|
mel_converter = feature_utils.mel_converter
|
||||||
strategy = model["strategy"]
|
strategy = model["strategy"]
|
||||||
|
|
||||||
if mode == "16k":
|
if mode == "16k":
|
||||||
# BigVGANVocoder wrapped inside BigVGAN — bypass the @inference_mode on the wrapper
|
|
||||||
vocoder = feature_utils.tod.vocoder.vocoder
|
vocoder = feature_utils.tod.vocoder.vocoder
|
||||||
sample_rate = 16_000
|
sample_rate = 16_000
|
||||||
elif mode == "44k":
|
elif mode == "44k":
|
||||||
# BigVGANv2 is the vocoder directly (no wrapper); no @inference_mode decorator
|
|
||||||
vocoder = feature_utils.tod.vocoder
|
vocoder = feature_utils.tod.vocoder
|
||||||
sample_rate = 44_100
|
sample_rate = 44_100
|
||||||
else:
|
else:
|
||||||
@@ -168,7 +167,7 @@ class SelvaBigvganTrainer:
|
|||||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
||||||
wav = wav.squeeze(0) # [L]
|
wav = wav.squeeze(0) # [L]
|
||||||
if wav.shape[0] >= segment_samples:
|
if wav.shape[0] >= segment_samples:
|
||||||
clips.append(wav)
|
clips.append(wav.cpu())
|
||||||
else:
|
else:
|
||||||
print(f" [BigVGAN] Skip {af.name}: shorter than {segment_seconds}s", flush=True)
|
print(f" [BigVGAN] Skip {af.name}: shorter than {segment_seconds}s", flush=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -188,21 +187,66 @@ class SelvaBigvganTrainer:
|
|||||||
|
|
||||||
mel_converter.to(device)
|
mel_converter.to(device)
|
||||||
|
|
||||||
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
# Run the entire training in a fresh thread.
|
||||||
|
#
|
||||||
|
# ComfyUI executes nodes inside torch.inference_mode(). Even with an inner
|
||||||
|
# inference_mode(False) context, factory functions and operations may still
|
||||||
|
# produce inference tensors in some environments (e.g. when the outer
|
||||||
|
# context is set via an async wrapper or a third-party hook).
|
||||||
|
#
|
||||||
|
# torch.inference_mode is THREAD-LOCAL. A new thread always starts with
|
||||||
|
# inference_mode disabled, so all tensor operations in the worker thread
|
||||||
|
# produce normal, autograd-compatible tensors — no flags to fight.
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
_result = [None]
|
||||||
|
_exc = [None]
|
||||||
|
|
||||||
|
def _worker():
|
||||||
|
try:
|
||||||
|
_result[0] = _do_train(
|
||||||
|
vocoder, mel_converter, clips,
|
||||||
|
device, dtype, strategy, feature_utils,
|
||||||
|
segment_samples, sample_rate,
|
||||||
|
steps, lr, batch_size, save_every, seed,
|
||||||
|
out_path, pbar,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
_exc[0] = e
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
t = threading.Thread(target=_worker, daemon=True)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
if _exc[0] is not None:
|
||||||
|
raise _exc[0]
|
||||||
|
return (_result[0],)
|
||||||
|
|
||||||
|
|
||||||
|
def _do_train(vocoder, mel_converter, clips,
|
||||||
|
device, dtype, strategy, feature_utils,
|
||||||
|
segment_samples, sample_rate,
|
||||||
|
steps, lr, batch_size, save_every, seed,
|
||||||
|
out_path, pbar):
|
||||||
|
"""Execute training. Called in a fresh thread — no inference_mode active."""
|
||||||
|
import torch.nn as nn_mod
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
# Fixed reference segment for eval samples — always clip 0, start 0
|
# Reference segment for eval samples — always clip 0, start 0
|
||||||
ref_clip = clips[0][:segment_samples].to(device, dtype) # [T]
|
ref_wav = clips[0][:segment_samples].to(device, dtype) # [T]
|
||||||
ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel]
|
ref_mel = mel_converter(ref_wav.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||||
|
|
||||||
def _save_sample(label):
|
def _save_sample(label):
|
||||||
"""Vocode the reference mel and save as .wav."""
|
|
||||||
try:
|
try:
|
||||||
# Vocoder may have been offloaded to CPU after training — match its device.
|
|
||||||
voc_device = next(vocoder.parameters()).device
|
voc_device = next(vocoder.parameters()).device
|
||||||
mel = ref_mel.to(voc_device)
|
mel = ref_mel.to(voc_device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
wav = vocoder(mel) # [1, 1, T] or [1, T]
|
wav = vocoder(mel)
|
||||||
if wav.dim() == 2:
|
if wav.dim() == 2:
|
||||||
wav = wav.unsqueeze(1)
|
wav = wav.unsqueeze(1)
|
||||||
wav = wav.float().cpu().clamp(-1, 1)
|
wav = wav.float().cpu().clamp(-1, 1)
|
||||||
@@ -212,98 +256,46 @@ class SelvaBigvganTrainer:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True)
|
print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True)
|
||||||
|
|
||||||
# Baseline: ground truth roundtrip before any fine-tuning
|
|
||||||
_save_sample("baseline")
|
_save_sample("baseline")
|
||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
# Replace inference-tensor parameters with fresh trainable parameters.
|
||||||
|
# (Vocoder was loaded by ComfyUI before this thread existed; its parameters
|
||||||
try:
|
# may carry the inference flag from that context. Here we're in a clean thread
|
||||||
with torch.inference_mode(False):
|
# so all new tensors are normal.)
|
||||||
with torch.enable_grad():
|
|
||||||
# Vocoder parameters are inference tensors (loaded inside ComfyUI's
|
|
||||||
# inference_mode). param.data = clone() only changes storage — the
|
|
||||||
# nn.Parameter object itself still carries the inference flag.
|
|
||||||
# Replace each parameter with a fresh nn.Parameter created here
|
|
||||||
# (inside inference_mode(False)) so the object itself is normal.
|
|
||||||
# param.data.clone() of an inference tensor still produces an
|
|
||||||
# inference tensor. Use torch.zeros + copy_ to create a genuinely
|
|
||||||
# fresh normal tensor, then wrap in nn.Parameter (created here,
|
|
||||||
# inside inference_mode(False), so it is a normal parameter).
|
|
||||||
import torch.nn as nn_mod
|
|
||||||
for module in vocoder.modules():
|
for module in vocoder.modules():
|
||||||
for pname, param in list(module._parameters.items()):
|
for pname, param in list(module._parameters.items()):
|
||||||
if param is not None:
|
if param is not None:
|
||||||
fresh = torch.zeros(
|
fresh = param.data.clone().detach().requires_grad_(True)
|
||||||
param.shape, device=param.device, dtype=param.dtype
|
module._parameters[pname] = nn_mod.Parameter(fresh)
|
||||||
)
|
|
||||||
fresh.copy_(param.data)
|
|
||||||
module._parameters[pname] = nn_mod.Parameter(
|
|
||||||
fresh, requires_grad=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# mel_converter and its submodules (e.g. Spectrogram.window) have
|
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
|
||||||
# inference-tensor buffers loaded in ComfyUI's outer inference_mode.
|
|
||||||
# Must iterate .modules() — ._buffers only covers direct buffers.
|
|
||||||
for sub in mel_converter.modules():
|
|
||||||
for bname, buf in list(sub._buffers.items()):
|
|
||||||
if buf is not None:
|
|
||||||
fresh = torch.zeros(
|
|
||||||
buf.shape, device=buf.device, dtype=buf.dtype
|
|
||||||
)
|
|
||||||
fresh.copy_(buf)
|
|
||||||
sub._buffers[bname] = fresh
|
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(
|
|
||||||
vocoder.parameters(), lr=lr, betas=(0.8, 0.99)
|
|
||||||
)
|
|
||||||
vocoder.train()
|
vocoder.train()
|
||||||
|
|
||||||
|
try:
|
||||||
for step in range(steps):
|
for step in range(steps):
|
||||||
# Sample random batch
|
# Sample random batch — clips are CPU floats, move to device
|
||||||
batch = []
|
batch = []
|
||||||
for _ in range(batch_size):
|
for _ in range(batch_size):
|
||||||
clip = random.choice(clips)
|
clip = random.choice(clips)
|
||||||
start = random.randint(0, clip.shape[0] - segment_samples)
|
start = random.randint(0, clip.shape[0] - segment_samples)
|
||||||
batch.append(clip[start : start + segment_samples])
|
batch.append(clip[start : start + segment_samples])
|
||||||
|
|
||||||
# clips were loaded in ComfyUI's outer inference_mode, so every
|
target_flat = torch.stack(batch).to(device, dtype) # [B, T]
|
||||||
# element is an inference tensor. torch.stack().clone() is still
|
|
||||||
# an inference tensor (the flag propagates through all ops).
|
|
||||||
# Use zeros+copy_ to produce a genuine normal tensor.
|
|
||||||
_stacked = torch.stack(batch).to(device, dtype)
|
|
||||||
target_flat = torch.zeros(
|
|
||||||
_stacked.shape, device=device, dtype=dtype
|
|
||||||
)
|
|
||||||
target_flat.copy_(_stacked)
|
|
||||||
del _stacked
|
|
||||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||||
|
|
||||||
# Compute target mel and guarantee it is not an inference tensor.
|
|
||||||
# Even with sanitized buffers a submodule we missed could still
|
|
||||||
# taint the output, so we always copy into a fresh tensor.
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_mel = mel_converter(target_flat)
|
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||||
target_mel = torch.empty(
|
|
||||||
_mel.shape, device=device, dtype=dtype
|
|
||||||
)
|
|
||||||
target_mel.copy_(_mel)
|
|
||||||
del _mel
|
|
||||||
|
|
||||||
# Vocoder forward: mel → waveform
|
|
||||||
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
||||||
|
|
||||||
# Align lengths
|
|
||||||
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
||||||
pred_t = pred_wav[..., :T]
|
pred_t = pred_wav[..., :T]
|
||||||
target_t = target_wav[..., :T]
|
target_t = target_wav[..., :T]
|
||||||
|
|
||||||
# Mel reconstruction loss — no no_grad: grad must flow
|
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, n_mels, T_mel']
|
||||||
# through pred_t → mel_converter → loss.
|
|
||||||
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, 80, T_mel']
|
|
||||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||||
|
|
||||||
# Multi-resolution STFT loss
|
|
||||||
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
|
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
|
||||||
|
|
||||||
loss = mel_loss + stft_loss
|
loss = mel_loss + stft_loss
|
||||||
@@ -337,4 +329,4 @@ class SelvaBigvganTrainer:
|
|||||||
torch.save({"generator": vocoder.state_dict()}, str(out_path))
|
torch.save({"generator": vocoder.state_dict()}, str(out_path))
|
||||||
print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
|
print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
|
||||||
_save_sample("final")
|
_save_sample("final")
|
||||||
return (str(out_path),)
|
return str(out_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user