From e870446b0f9822f8f8fb099b81c5127258f30b16 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 02:30:53 +0200 Subject: [PATCH] fix: run BigVGAN training in a fresh thread to escape inference_mode torch.inference_mode is thread-local. ComfyUI sets it on the node-execution thread; inference_mode(False) alone is insufficient to escape it in some environments (e.g. async wrappers, lora-manager hook). A new thread always starts clean. Moved all training logic into _do_train() called via threading.Thread so every tensor is a normal autograd tensor by default. Simplified parameter cloning: clone().detach().requires_grad_(True). Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_bigvgan_trainer.py | 264 ++++++++++++++++----------------- 1 file changed, 128 insertions(+), 136 deletions(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index f235d17..c3978ac 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -12,6 +12,7 @@ checkpoint so it can be loaded with SelVA BigVGAN Loader. """ import random +import threading from pathlib import Path import torch @@ -121,17 +122,15 @@ class SelvaBigvganTrainer: device = get_device() mode = model["mode"] - dtype = model["dtype"] # bf16/fp16/fp32 — must match mel_converter buffers + dtype = model["dtype"] feature_utils = model["feature_utils"] mel_converter = feature_utils.mel_converter strategy = model["strategy"] if mode == "16k": - # BigVGANVocoder wrapped inside BigVGAN — bypass the @inference_mode on the wrapper vocoder = feature_utils.tod.vocoder.vocoder sample_rate = 16_000 elif mode == "44k": - # BigVGANv2 is the vocoder directly (no wrapper); no @inference_mode decorator vocoder = feature_utils.tod.vocoder sample_rate = 44_100 else: @@ -168,7 +167,7 @@ class SelvaBigvganTrainer: wav = torchaudio.functional.resample(wav, sr, sample_rate) wav = wav.squeeze(0) # [L] if wav.shape[0] >= segment_samples: - clips.append(wav) + clips.append(wav.cpu()) else: print(f" [BigVGAN] Skip {af.name}: shorter than {segment_seconds}s", flush=True) except Exception as e: @@ -188,153 +187,146 @@ class SelvaBigvganTrainer: mel_converter.to(device) - torch.manual_seed(seed) - random.seed(seed) - - # Fixed reference segment for eval samples — always clip 0, start 0 - ref_clip = clips[0][:segment_samples].to(device, dtype) # [T] - ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel] - - def _save_sample(label): - """Vocode the reference mel and save as .wav.""" - try: - # Vocoder may have been offloaded to CPU after training — match its device. - voc_device = next(vocoder.parameters()).device - mel = ref_mel.to(voc_device) - with torch.no_grad(): - wav = vocoder(mel) # [1, 1, T] or [1, T] - if wav.dim() == 2: - wav = wav.unsqueeze(1) - wav = wav.float().cpu().clamp(-1, 1) - wav_path = out_path.parent / f"{out_path.stem}_{label}.wav" - torchaudio.save(str(wav_path), wav.squeeze(0), sample_rate) - print(f"[BigVGAN] Sample saved: {wav_path}", flush=True) - except Exception as e: - print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True) - - # Baseline: ground truth roundtrip before any fine-tuning - _save_sample("baseline") - pbar = comfy.utils.ProgressBar(steps) + # ----------------------------------------------------------------------- + # Run the entire training in a fresh thread. + # + # ComfyUI executes nodes inside torch.inference_mode(). Even with an inner + # inference_mode(False) context, factory functions and operations may still + # produce inference tensors in some environments (e.g. when the outer + # context is set via an async wrapper or a third-party hook). + # + # torch.inference_mode is THREAD-LOCAL. A new thread always starts with + # inference_mode disabled, so all tensor operations in the worker thread + # produce normal, autograd-compatible tensors — no flags to fight. + # ----------------------------------------------------------------------- + _result = [None] + _exc = [None] + + def _worker(): + try: + _result[0] = _do_train( + vocoder, mel_converter, clips, + device, dtype, strategy, feature_utils, + segment_samples, sample_rate, + steps, lr, batch_size, save_every, seed, + out_path, pbar, + ) + except Exception as e: + _exc[0] = e + traceback.print_exc() + + t = threading.Thread(target=_worker, daemon=True) + t.start() + t.join() + + if _exc[0] is not None: + raise _exc[0] + return (_result[0],) + + +def _do_train(vocoder, mel_converter, clips, + device, dtype, strategy, feature_utils, + segment_samples, sample_rate, + steps, lr, batch_size, save_every, seed, + out_path, pbar): + """Execute training. Called in a fresh thread — no inference_mode active.""" + import torch.nn as nn_mod + + torch.manual_seed(seed) + random.seed(seed) + + # Reference segment for eval samples — always clip 0, start 0 + ref_wav = clips[0][:segment_samples].to(device, dtype) # [T] + ref_mel = mel_converter(ref_wav.unsqueeze(0)) # [1, n_mels, T_mel] + + def _save_sample(label): try: - with torch.inference_mode(False): - with torch.enable_grad(): - # Vocoder parameters are inference tensors (loaded inside ComfyUI's - # inference_mode). param.data = clone() only changes storage — the - # nn.Parameter object itself still carries the inference flag. - # Replace each parameter with a fresh nn.Parameter created here - # (inside inference_mode(False)) so the object itself is normal. - # param.data.clone() of an inference tensor still produces an - # inference tensor. Use torch.zeros + copy_ to create a genuinely - # fresh normal tensor, then wrap in nn.Parameter (created here, - # inside inference_mode(False), so it is a normal parameter). - import torch.nn as nn_mod - for module in vocoder.modules(): - for pname, param in list(module._parameters.items()): - if param is not None: - fresh = torch.zeros( - param.shape, device=param.device, dtype=param.dtype - ) - fresh.copy_(param.data) - module._parameters[pname] = nn_mod.Parameter( - fresh, requires_grad=True - ) + voc_device = next(vocoder.parameters()).device + mel = ref_mel.to(voc_device) + with torch.no_grad(): + wav = vocoder(mel) + if wav.dim() == 2: + wav = wav.unsqueeze(1) + wav = wav.float().cpu().clamp(-1, 1) + wav_path = out_path.parent / f"{out_path.stem}_{label}.wav" + torchaudio.save(str(wav_path), wav.squeeze(0), sample_rate) + print(f"[BigVGAN] Sample saved: {wav_path}", flush=True) + except Exception as e: + print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True) - # mel_converter and its submodules (e.g. Spectrogram.window) have - # inference-tensor buffers loaded in ComfyUI's outer inference_mode. - # Must iterate .modules() — ._buffers only covers direct buffers. - for sub in mel_converter.modules(): - for bname, buf in list(sub._buffers.items()): - if buf is not None: - fresh = torch.zeros( - buf.shape, device=buf.device, dtype=buf.dtype - ) - fresh.copy_(buf) - sub._buffers[bname] = fresh + _save_sample("baseline") - optimizer = torch.optim.AdamW( - vocoder.parameters(), lr=lr, betas=(0.8, 0.99) - ) - vocoder.train() + # Replace inference-tensor parameters with fresh trainable parameters. + # (Vocoder was loaded by ComfyUI before this thread existed; its parameters + # may carry the inference flag from that context. Here we're in a clean thread + # so all new tensors are normal.) + for module in vocoder.modules(): + for pname, param in list(module._parameters.items()): + if param is not None: + fresh = param.data.clone().detach().requires_grad_(True) + module._parameters[pname] = nn_mod.Parameter(fresh) - for step in range(steps): - # Sample random batch - batch = [] - for _ in range(batch_size): - clip = random.choice(clips) - start = random.randint(0, clip.shape[0] - segment_samples) - batch.append(clip[start : start + segment_samples]) + optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99)) + vocoder.train() - # clips were loaded in ComfyUI's outer inference_mode, so every - # element is an inference tensor. torch.stack().clone() is still - # an inference tensor (the flag propagates through all ops). - # Use zeros+copy_ to produce a genuine normal tensor. - _stacked = torch.stack(batch).to(device, dtype) - target_flat = torch.zeros( - _stacked.shape, device=device, dtype=dtype - ) - target_flat.copy_(_stacked) - del _stacked - target_wav = target_flat.unsqueeze(1) # [B, 1, T] + try: + for step in range(steps): + # Sample random batch — clips are CPU floats, move to device + batch = [] + for _ in range(batch_size): + clip = random.choice(clips) + start = random.randint(0, clip.shape[0] - segment_samples) + batch.append(clip[start : start + segment_samples]) - # Compute target mel and guarantee it is not an inference tensor. - # Even with sanitized buffers a submodule we missed could still - # taint the output, so we always copy into a fresh tensor. - with torch.no_grad(): - _mel = mel_converter(target_flat) - target_mel = torch.empty( - _mel.shape, device=device, dtype=dtype - ) - target_mel.copy_(_mel) - del _mel + target_flat = torch.stack(batch).to(device, dtype) # [B, T] + target_wav = target_flat.unsqueeze(1) # [B, 1, T] - # Vocoder forward: mel → waveform - pred_wav = vocoder(target_mel) # [B, 1, T_wav] + with torch.no_grad(): + target_mel = mel_converter(target_flat) # [B, n_mels, T_mel] - # Align lengths - T = min(pred_wav.shape[-1], target_wav.shape[-1]) - pred_t = pred_wav[..., :T] - target_t = target_wav[..., :T] + pred_wav = vocoder(target_mel) # [B, 1, T_wav] - # Mel reconstruction loss — no no_grad: grad must flow - # through pred_t → mel_converter → loss. - pred_mel = mel_converter(pred_t.squeeze(1)) # [B, 80, T_mel'] - T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) - mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel]) + T = min(pred_wav.shape[-1], target_wav.shape[-1]) + pred_t = pred_wav[..., :T] + target_t = target_wav[..., :T] - # Multi-resolution STFT loss - stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device) + pred_mel = mel_converter(pred_t.squeeze(1)) # [B, n_mels, T_mel'] + T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) + mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel]) - loss = mel_loss + stft_loss - optimizer.zero_grad() - loss.backward() - torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0) - optimizer.step() + stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device) - pbar.update(1) + loss = mel_loss + stft_loss + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0) + optimizer.step() - if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1: - print(f"[BigVGAN] {step+1}/{steps} " - f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} " - f"total={loss.item():.4f}", flush=True) + pbar.update(1) - if (step + 1) % save_every == 0 and (step + 1) < steps: - step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}" - torch.save({"generator": vocoder.state_dict()}, str(step_path)) - print(f"[BigVGAN] Checkpoint: {step_path}", flush=True) - vocoder.eval() - _save_sample(f"step{step+1}") - vocoder.train() + if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1: + print(f"[BigVGAN] {step+1}/{steps} " + f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} " + f"total={loss.item():.4f}", flush=True) - finally: - vocoder.requires_grad_(False) - vocoder.eval() - if strategy == "offload_to_cpu": - feature_utils.to("cpu") - soft_empty_cache() + if (step + 1) % save_every == 0 and (step + 1) < steps: + step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}" + torch.save({"generator": vocoder.state_dict()}, str(step_path)) + print(f"[BigVGAN] Checkpoint: {step_path}", flush=True) + vocoder.eval() + _save_sample(f"step{step+1}") + vocoder.train() - torch.save({"generator": vocoder.state_dict()}, str(out_path)) - print(f"\n[BigVGAN] Saved: {out_path}", flush=True) - _save_sample("final") - return (str(out_path),) + finally: + vocoder.requires_grad_(False) + vocoder.eval() + if strategy == "offload_to_cpu": + feature_utils.to("cpu") + soft_empty_cache() + + torch.save({"generator": vocoder.state_dict()}, str(out_path)) + print(f"\n[BigVGAN] Saved: {out_path}", flush=True) + _save_sample("final") + return str(out_path)