fix: pre-compute text CLIP embeddings in main thread to avoid inference tensor crash

CLIP weights are inference tensors from ComfyUI loading. inference_mode
is thread-local, so the worker thread can't use CLIP even with a context
manager. Pre-compute all text embeddings in the main thread (where
inference_mode IS active), clone+detach to normal tensors, and pass them
to the worker via text_clip_cache dict. CLIP no longer needs to be on
GPU during pre-generation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 01:19:44 +02:00
parent 32e5344ea2
commit f8d4d77b0d
+57 -28
View File
@@ -402,7 +402,8 @@ def _lora_mel_cache_key(lora_adapter_path, data_dir, seed, num_steps,
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
sample_rate, duration, seed=42, num_steps=25,
cfg_strength=4.5, cache_dir=None):
cfg_strength=4.5, cache_dir=None,
text_clip_cache=None):
"""Generate LoRA mels for all clips with matching audio in data_dir.
Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
@@ -413,6 +414,10 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
If cache_dir is provided, results are cached to disk and reused when
generation parameters haven't changed.
text_clip_cache: dict mapping npz filename → pre-computed text CLIP
embedding tensor [1, seq, dim]. Pre-computed in the main thread where
inference_mode is active (CLIP weights are inference tensors).
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
"""
# ── Check cache ──────────────────────────────────────────────────────────
@@ -471,32 +476,18 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
raise ValueError(f"[BigVGAN] No .npz files in {data_dir}"
"point data_dir to your LoRA training features directory")
# Load prompt map if available (same logic as LoRA trainer)
prompt_map = {}
prompts_file = data_dir / "prompts.txt"
if prompts_file.exists():
for line in prompts_file.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "|" in line:
fname, prompt = line.split("|", 1)
prompt_map[fname.strip()] = prompt.strip()
default_prompt = data_dir.name
if text_clip_cache is None:
text_clip_cache = {}
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
rng = torch.Generator(device=device).manual_seed(seed)
# Move only the components we need to GPU for generation:
# - tod (VAE+vocoder) for decode
# - clip_model for encode_text_clip
# Move only tod (VAE+vocoder) to GPU for decode.
# CLIP is NOT needed here — text embeddings are pre-computed in the main
# thread and passed via text_clip_cache.
tod = feature_utils.tod
tod_orig_dev = next(tod.parameters()).device
tod.to(device)
clip_model = feature_utils.clip_model
if clip_model is not None:
clip_orig_dev = next(clip_model.parameters()).device
clip_model.to(device)
pairs = []
try:
@@ -525,12 +516,12 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
# Text CLIP encoding
prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt))
if isinstance(prompt, np.ndarray):
prompt = str(prompt)
with torch.inference_mode():
text_clip = feature_utils.encode_text_clip([prompt]).to(device, dtype)
# Text CLIP embedding (pre-computed in main thread)
if npz_path.name in text_clip_cache:
text_clip = text_clip_cache[npz_path.name].to(device, dtype)
else:
print(f" [BigVGAN] No text embedding for {npz_path.name}, skipping", flush=True)
continue
# Load clean audio
try:
@@ -572,8 +563,6 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
finally:
tod.to(tod_orig_dev)
if clip_model is not None:
clip_model.to(clip_orig_dev)
del generator
soft_empty_cache()
@@ -797,6 +786,45 @@ class SelvaBigvganTrainer:
# _pregenerate_lora_mels handles its own device management for CLIP/tod.
mel_converter.to(device)
# Pre-compute text CLIP embeddings in the main thread.
# CLIP weights are inference tensors from ComfyUI loading — they only
# work in inference_mode, which is thread-local and active here but NOT
# in the worker thread. Pre-computing avoids needing CLIP on GPU in the
# worker. Results are cloned+detached so they're normal tensors.
text_clip_cache = {}
if lora_path is not None:
npz_files = sorted(data_dir.glob("*.npz"))
if npz_files:
prompt_map = {}
prompts_file = data_dir / "prompts.txt"
if prompts_file.exists():
for line in prompts_file.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "|" in line:
fname, prompt = line.split("|", 1)
prompt_map[fname.strip()] = prompt.strip()
default_prompt = data_dir.name
# Temporarily move CLIP to GPU for encoding
clip_model = feature_utils.clip_model
if clip_model is not None:
clip_model.to(device)
try:
for npz_path in npz_files:
data = dict(np.load(str(npz_path), allow_pickle=False))
prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt))
if isinstance(prompt, np.ndarray):
prompt = str(prompt)
tc = feature_utils.encode_text_clip([prompt])
text_clip_cache[npz_path.name] = tc.clone().detach().cpu()
finally:
if clip_model is not None:
clip_model.to("cpu")
soft_empty_cache()
print(f"[BigVGAN] Pre-encoded {len(text_clip_cache)} text CLIP embeddings", flush=True)
pbar = comfy.utils.ProgressBar(steps)
# -----------------------------------------------------------------------
@@ -826,6 +854,7 @@ class SelvaBigvganTrainer:
device, dtype, sample_rate,
seq_cfg.duration, seed=seed,
cache_dir=out_path.parent,
text_clip_cache=text_clip_cache,
)
if not lora_mel_pairs:
raise RuntimeError(