fix: clone inference tensors at thread entry to strip the inference flag
torch.inference_mode is thread-local, but the inference flag lives on the tensor object. Operations on inference tensors always propagate it, even in a clean thread. The only escape is .clone() called outside inference_mode. At thread entry (inference_mode disabled): clone clips and mel_converter buffers to get clean normal tensors before any training computation. Vocoder parameter clone() also now works correctly in this thread context. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -231,9 +231,28 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
segment_samples, sample_rate,
|
segment_samples, sample_rate,
|
||||||
steps, lr, batch_size, save_every, seed,
|
steps, lr, batch_size, save_every, seed,
|
||||||
out_path, pbar):
|
out_path, pbar):
|
||||||
"""Execute training. Called in a fresh thread — no inference_mode active."""
|
"""Execute training. Called in a fresh thread — no inference_mode active.
|
||||||
|
|
||||||
|
Even though inference_mode is off here, tensors created in the calling
|
||||||
|
thread's inference_mode carry the inference flag on the object itself.
|
||||||
|
Operations on inference tensors produce inference tensors regardless of
|
||||||
|
the current context. The ONLY way to strip the flag is to call .clone()
|
||||||
|
from outside inference_mode — which is exactly where we are now.
|
||||||
|
"""
|
||||||
import torch.nn as nn_mod
|
import torch.nn as nn_mod
|
||||||
|
|
||||||
|
# ── Strip inference flag from all inputs that came from the main thread ──
|
||||||
|
# 1. Audio clips (loaded in ComfyUI's inference_mode).
|
||||||
|
clips = [c.clone() for c in clips]
|
||||||
|
|
||||||
|
# 2. mel_converter buffers (mel_basis, hann_window) — same origin.
|
||||||
|
for name, buf in list(mel_converter._buffers.items()):
|
||||||
|
if buf is not None:
|
||||||
|
mel_converter._buffers[name] = buf.clone()
|
||||||
|
|
||||||
|
# 3. Vocoder parameters are handled below with clone().detach().
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user