fix: clone inference tensors at thread entry to strip the inference flag

torch.inference_mode is thread-local, but the inference flag lives on the
tensor object. Operations on inference tensors always propagate it, even in
a clean thread. The only escape is .clone() called outside inference_mode.
At thread entry (inference_mode disabled): clone clips and mel_converter
buffers to get clean normal tensors before any training computation.
Vocoder parameter clone() also now works correctly in this thread context.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 02:35:48 +02:00
parent e870446b0f
commit 78f8aa98ad
+20 -1
View File
@@ -231,9 +231,28 @@ def _do_train(vocoder, mel_converter, clips,
segment_samples, sample_rate, segment_samples, sample_rate,
steps, lr, batch_size, save_every, seed, steps, lr, batch_size, save_every, seed,
out_path, pbar): out_path, pbar):
"""Execute training. Called in a fresh thread — no inference_mode active.""" """Execute training. Called in a fresh thread — no inference_mode active.
Even though inference_mode is off here, tensors created in the calling
thread's inference_mode carry the inference flag on the object itself.
Operations on inference tensors produce inference tensors regardless of
the current context. The ONLY way to strip the flag is to call .clone()
from outside inference_mode — which is exactly where we are now.
"""
import torch.nn as nn_mod import torch.nn as nn_mod
# ── Strip inference flag from all inputs that came from the main thread ──
# 1. Audio clips (loaded in ComfyUI's inference_mode).
clips = [c.clone() for c in clips]
# 2. mel_converter buffers (mel_basis, hann_window) — same origin.
for name, buf in list(mel_converter._buffers.items()):
if buf is not None:
mel_converter._buffers[name] = buf.clone()
# 3. Vocoder parameters are handled below with clone().detach().
# ─────────────────────────────────────────────────────────────────────────
torch.manual_seed(seed) torch.manual_seed(seed)
random.seed(seed) random.seed(seed)