fix: run BigVGAN training in a fresh thread to escape inference_mode
torch.inference_mode is thread-local. ComfyUI enables it on the node-execution thread, and in some environments (e.g. async wrappers, the lora-manager hook) a nested inference_mode(False) alone is not enough to escape it. A new thread always starts clean, with inference mode disabled.

All training logic has been moved into _do_train(), which is run via threading.Thread, so every tensor created during training is a normal autograd tensor by default.

Parameter cloning is simplified to clone().detach().requires_grad_(True).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+128
-136
@@ -12,6 +12,7 @@ checkpoint so it can be loaded with SelVA BigVGAN Loader.
|
||||
"""
|
||||
|
||||
import random
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@@ -121,17 +122,15 @@ class SelvaBigvganTrainer:
|
||||
|
||||
device = get_device()
|
||||
mode = model["mode"]
|
||||
dtype = model["dtype"] # bf16/fp16/fp32 — must match mel_converter buffers
|
||||
dtype = model["dtype"]
|
||||
feature_utils = model["feature_utils"]
|
||||
mel_converter = feature_utils.mel_converter
|
||||
strategy = model["strategy"]
|
||||
|
||||
if mode == "16k":
|
||||
# BigVGANVocoder wrapped inside BigVGAN — bypass the @inference_mode on the wrapper
|
||||
vocoder = feature_utils.tod.vocoder.vocoder
|
||||
sample_rate = 16_000
|
||||
elif mode == "44k":
|
||||
# BigVGANv2 is the vocoder directly (no wrapper); no @inference_mode decorator
|
||||
vocoder = feature_utils.tod.vocoder
|
||||
sample_rate = 44_100
|
||||
else:
|
||||
@@ -168,7 +167,7 @@ class SelvaBigvganTrainer:
|
||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
||||
wav = wav.squeeze(0) # [L]
|
||||
if wav.shape[0] >= segment_samples:
|
||||
clips.append(wav)
|
||||
clips.append(wav.cpu())
|
||||
else:
|
||||
print(f" [BigVGAN] Skip {af.name}: shorter than {segment_seconds}s", flush=True)
|
||||
except Exception as e:
|
||||
@@ -188,153 +187,146 @@ class SelvaBigvganTrainer:
|
||||
|
||||
mel_converter.to(device)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
# Fixed reference segment for eval samples — always clip 0, start 0
|
||||
ref_clip = clips[0][:segment_samples].to(device, dtype) # [T]
|
||||
ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||
|
||||
def _save_sample(label):
|
||||
"""Vocode the reference mel and save as .wav."""
|
||||
try:
|
||||
# Vocoder may have been offloaded to CPU after training — match its device.
|
||||
voc_device = next(vocoder.parameters()).device
|
||||
mel = ref_mel.to(voc_device)
|
||||
with torch.no_grad():
|
||||
wav = vocoder(mel) # [1, 1, T] or [1, T]
|
||||
if wav.dim() == 2:
|
||||
wav = wav.unsqueeze(1)
|
||||
wav = wav.float().cpu().clamp(-1, 1)
|
||||
wav_path = out_path.parent / f"{out_path.stem}_{label}.wav"
|
||||
torchaudio.save(str(wav_path), wav.squeeze(0), sample_rate)
|
||||
print(f"[BigVGAN] Sample saved: {wav_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True)
|
||||
|
||||
# Baseline: ground truth roundtrip before any fine-tuning
|
||||
_save_sample("baseline")
|
||||
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Run the entire training in a fresh thread.
|
||||
#
|
||||
# ComfyUI executes nodes inside torch.inference_mode(). Even with an inner
|
||||
# inference_mode(False) context, factory functions and operations may still
|
||||
# produce inference tensors in some environments (e.g. when the outer
|
||||
# context is set via an async wrapper or a third-party hook).
|
||||
#
|
||||
# torch.inference_mode is THREAD-LOCAL. A new thread always starts with
|
||||
# inference_mode disabled, so all tensor operations in the worker thread
|
||||
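# produce normal, autograd-compatible tensors — no flags to fight. Tensors created earlier under inference_mode (e.g. the vocoder parameters) may still carry the flag, so _do_train re-clones them.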
|
||||
# -----------------------------------------------------------------------
|
||||
_result = [None]
|
||||
_exc = [None]
|
||||
|
||||
def _worker():
|
||||
try:
|
||||
_result[0] = _do_train(
|
||||
vocoder, mel_converter, clips,
|
||||
device, dtype, strategy, feature_utils,
|
||||
segment_samples, sample_rate,
|
||||
steps, lr, batch_size, save_every, seed,
|
||||
out_path, pbar,
|
||||
)
|
||||
except Exception as e:
|
||||
_exc[0] = e
|
||||
traceback.print_exc()
|
||||
|
||||
t = threading.Thread(target=_worker, daemon=True)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
if _exc[0] is not None:
|
||||
raise _exc[0]
|
||||
return (_result[0],)
|
||||
|
||||
|
||||
def _do_train(vocoder, mel_converter, clips,
|
||||
device, dtype, strategy, feature_utils,
|
||||
segment_samples, sample_rate,
|
||||
steps, lr, batch_size, save_every, seed,
|
||||
out_path, pbar):
|
||||
"""Execute training. Called in a fresh thread — no inference_mode active."""
|
||||
import torch.nn as nn_mod
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
# Reference segment for eval samples — always clip 0, start 0
|
||||
ref_wav = clips[0][:segment_samples].to(device, dtype) # [T]
|
||||
ref_mel = mel_converter(ref_wav.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||
|
||||
def _save_sample(label):
|
||||
try:
|
||||
with torch.inference_mode(False):
|
||||
with torch.enable_grad():
|
||||
# Vocoder parameters are inference tensors (loaded inside ComfyUI's
|
||||
# inference_mode). param.data = clone() only changes storage — the
|
||||
# nn.Parameter object itself still carries the inference flag.
|
||||
# Replace each parameter with a fresh nn.Parameter created here
|
||||
# (inside inference_mode(False)) so the object itself is normal.
|
||||
# param.data.clone() of an inference tensor still produces an
|
||||
# inference tensor. Use torch.zeros + copy_ to create a genuinely
|
||||
# fresh normal tensor, then wrap in nn.Parameter (created here,
|
||||
# inside inference_mode(False), so it is a normal parameter).
|
||||
import torch.nn as nn_mod
|
||||
for module in vocoder.modules():
|
||||
for pname, param in list(module._parameters.items()):
|
||||
if param is not None:
|
||||
fresh = torch.zeros(
|
||||
param.shape, device=param.device, dtype=param.dtype
|
||||
)
|
||||
fresh.copy_(param.data)
|
||||
module._parameters[pname] = nn_mod.Parameter(
|
||||
fresh, requires_grad=True
|
||||
)
|
||||
voc_device = next(vocoder.parameters()).device
|
||||
mel = ref_mel.to(voc_device)
|
||||
with torch.no_grad():
|
||||
wav = vocoder(mel)
|
||||
if wav.dim() == 2:
|
||||
wav = wav.unsqueeze(1)
|
||||
wav = wav.float().cpu().clamp(-1, 1)
|
||||
wav_path = out_path.parent / f"{out_path.stem}_{label}.wav"
|
||||
torchaudio.save(str(wav_path), wav.squeeze(0), sample_rate)
|
||||
print(f"[BigVGAN] Sample saved: {wav_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True)
|
||||
|
||||
# mel_converter and its submodules (e.g. Spectrogram.window) have
|
||||
# inference-tensor buffers loaded in ComfyUI's outer inference_mode.
|
||||
# Must iterate .modules() — ._buffers only covers direct buffers.
|
||||
for sub in mel_converter.modules():
|
||||
for bname, buf in list(sub._buffers.items()):
|
||||
if buf is not None:
|
||||
fresh = torch.zeros(
|
||||
buf.shape, device=buf.device, dtype=buf.dtype
|
||||
)
|
||||
fresh.copy_(buf)
|
||||
sub._buffers[bname] = fresh
|
||||
_save_sample("baseline")
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
vocoder.parameters(), lr=lr, betas=(0.8, 0.99)
|
||||
)
|
||||
vocoder.train()
|
||||
# Replace inference-tensor parameters with fresh trainable parameters.
|
||||
# (Vocoder was loaded by ComfyUI before this thread existed; its parameters
|
||||
# may carry the inference flag from that context. Here we're in a clean thread
|
||||
# so all new tensors are normal.)
|
||||
for module in vocoder.modules():
|
||||
for pname, param in list(module._parameters.items()):
|
||||
if param is not None:
|
||||
fresh = param.data.clone().detach().requires_grad_(True)
|
||||
module._parameters[pname] = nn_mod.Parameter(fresh)
|
||||
|
||||
for step in range(steps):
|
||||
# Sample random batch
|
||||
batch = []
|
||||
for _ in range(batch_size):
|
||||
clip = random.choice(clips)
|
||||
start = random.randint(0, clip.shape[0] - segment_samples)
|
||||
batch.append(clip[start : start + segment_samples])
|
||||
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
|
||||
vocoder.train()
|
||||
|
||||
# clips were loaded in ComfyUI's outer inference_mode, so every
|
||||
# element is an inference tensor. torch.stack().clone() is still
|
||||
# an inference tensor (the flag propagates through all ops).
|
||||
# Use zeros+copy_ to produce a genuine normal tensor.
|
||||
_stacked = torch.stack(batch).to(device, dtype)
|
||||
target_flat = torch.zeros(
|
||||
_stacked.shape, device=device, dtype=dtype
|
||||
)
|
||||
target_flat.copy_(_stacked)
|
||||
del _stacked
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
try:
|
||||
for step in range(steps):
|
||||
# Sample random batch — clips are CPU floats, move to device
|
||||
batch = []
|
||||
for _ in range(batch_size):
|
||||
clip = random.choice(clips)
|
||||
start = random.randint(0, clip.shape[0] - segment_samples)
|
||||
batch.append(clip[start : start + segment_samples])
|
||||
|
||||
# Compute target mel and guarantee it is not an inference tensor.
|
||||
# Even with sanitized buffers a submodule we missed could still
|
||||
# taint the output, so we always copy into a fresh tensor.
|
||||
with torch.no_grad():
|
||||
_mel = mel_converter(target_flat)
|
||||
target_mel = torch.empty(
|
||||
_mel.shape, device=device, dtype=dtype
|
||||
)
|
||||
target_mel.copy_(_mel)
|
||||
del _mel
|
||||
target_flat = torch.stack(batch).to(device, dtype) # [B, T]
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
|
||||
# Vocoder forward: mel → waveform
|
||||
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
||||
with torch.no_grad():
|
||||
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||
|
||||
# Align lengths
|
||||
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
||||
pred_t = pred_wav[..., :T]
|
||||
target_t = target_wav[..., :T]
|
||||
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
||||
|
||||
# Mel reconstruction loss — no no_grad: grad must flow
|
||||
# through pred_t → mel_converter → loss.
|
||||
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, 80, T_mel']
|
||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
||||
pred_t = pred_wav[..., :T]
|
||||
target_t = target_wav[..., :T]
|
||||
|
||||
# Multi-resolution STFT loss
|
||||
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
|
||||
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, n_mels, T_mel']
|
||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||
|
||||
loss = mel_loss + stft_loss
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
|
||||
|
||||
pbar.update(1)
|
||||
loss = mel_loss + stft_loss
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
|
||||
if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1:
|
||||
print(f"[BigVGAN] {step+1}/{steps} "
|
||||
f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} "
|
||||
f"total={loss.item():.4f}", flush=True)
|
||||
pbar.update(1)
|
||||
|
||||
if (step + 1) % save_every == 0 and (step + 1) < steps:
|
||||
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
|
||||
torch.save({"generator": vocoder.state_dict()}, str(step_path))
|
||||
print(f"[BigVGAN] Checkpoint: {step_path}", flush=True)
|
||||
vocoder.eval()
|
||||
_save_sample(f"step{step+1}")
|
||||
vocoder.train()
|
||||
if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1:
|
||||
print(f"[BigVGAN] {step+1}/{steps} "
|
||||
f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} "
|
||||
f"total={loss.item():.4f}", flush=True)
|
||||
|
||||
finally:
|
||||
vocoder.requires_grad_(False)
|
||||
vocoder.eval()
|
||||
if strategy == "offload_to_cpu":
|
||||
feature_utils.to("cpu")
|
||||
soft_empty_cache()
|
||||
if (step + 1) % save_every == 0 and (step + 1) < steps:
|
||||
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
|
||||
torch.save({"generator": vocoder.state_dict()}, str(step_path))
|
||||
print(f"[BigVGAN] Checkpoint: {step_path}", flush=True)
|
||||
vocoder.eval()
|
||||
_save_sample(f"step{step+1}")
|
||||
vocoder.train()
|
||||
|
||||
torch.save({"generator": vocoder.state_dict()}, str(out_path))
|
||||
print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
|
||||
_save_sample("final")
|
||||
return (str(out_path),)
|
||||
finally:
|
||||
vocoder.requires_grad_(False)
|
||||
vocoder.eval()
|
||||
if strategy == "offload_to_cpu":
|
||||
feature_utils.to("cpu")
|
||||
soft_empty_cache()
|
||||
|
||||
torch.save({"generator": vocoder.state_dict()}, str(out_path))
|
||||
print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
|
||||
_save_sample("final")
|
||||
return str(out_path)
|
||||
|
||||
Reference in New Issue
Block a user