fix: pre-compute text CLIP embeddings in main thread to avoid inference tensor crash
CLIP weights are inference tensors from ComfyUI loading. inference_mode is thread-local, so the worker thread can't use CLIP even with a context manager. Pre-compute all text embeddings in the main thread (where inference_mode IS active), clone+detach to normal tensors, and pass them to the worker via text_clip_cache dict. CLIP no longer needs to be on GPU during pre-generation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -402,7 +402,8 @@ def _lora_mel_cache_key(lora_adapter_path, data_dir, seed, num_steps,
|
||||
|
||||
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
sample_rate, duration, seed=42, num_steps=25,
|
||||
cfg_strength=4.5, cache_dir=None):
|
||||
cfg_strength=4.5, cache_dir=None,
|
||||
text_clip_cache=None):
|
||||
"""Generate LoRA mels for all clips with matching audio in data_dir.
|
||||
|
||||
Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
|
||||
@@ -413,6 +414,10 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
If cache_dir is provided, results are cached to disk and reused when
|
||||
generation parameters haven't changed.
|
||||
|
||||
text_clip_cache: dict mapping npz filename → pre-computed text CLIP
|
||||
embedding tensor [1, seq, dim]. Pre-computed in the main thread where
|
||||
inference_mode is active (CLIP weights are inference tensors).
|
||||
|
||||
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
|
||||
"""
|
||||
# ── Check cache ──────────────────────────────────────────────────────────
|
||||
@@ -471,32 +476,18 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
raise ValueError(f"[BigVGAN] No .npz files in {data_dir} — "
|
||||
"point data_dir to your LoRA training features directory")
|
||||
|
||||
# Load prompt map if available (same logic as LoRA trainer)
|
||||
prompt_map = {}
|
||||
prompts_file = data_dir / "prompts.txt"
|
||||
if prompts_file.exists():
|
||||
for line in prompts_file.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if "|" in line:
|
||||
fname, prompt = line.split("|", 1)
|
||||
prompt_map[fname.strip()] = prompt.strip()
|
||||
default_prompt = data_dir.name
|
||||
if text_clip_cache is None:
|
||||
text_clip_cache = {}
|
||||
|
||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
||||
rng = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
# Move only the components we need to GPU for generation:
|
||||
# - tod (VAE+vocoder) for decode
|
||||
# - clip_model for encode_text_clip
|
||||
# Move only tod (VAE+vocoder) to GPU for decode.
|
||||
# CLIP is NOT needed here — text embeddings are pre-computed in the main
|
||||
# thread and passed via text_clip_cache.
|
||||
tod = feature_utils.tod
|
||||
tod_orig_dev = next(tod.parameters()).device
|
||||
tod.to(device)
|
||||
clip_model = feature_utils.clip_model
|
||||
if clip_model is not None:
|
||||
clip_orig_dev = next(clip_model.parameters()).device
|
||||
clip_model.to(device)
|
||||
|
||||
pairs = []
|
||||
try:
|
||||
@@ -525,12 +516,12 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
elif sync_f.shape[1] > s_tgt:
|
||||
sync_f = sync_f[:, :s_tgt, :]
|
||||
|
||||
# Text CLIP encoding
|
||||
prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt))
|
||||
if isinstance(prompt, np.ndarray):
|
||||
prompt = str(prompt)
|
||||
with torch.inference_mode():
|
||||
text_clip = feature_utils.encode_text_clip([prompt]).to(device, dtype)
|
||||
# Text CLIP embedding (pre-computed in main thread)
|
||||
if npz_path.name in text_clip_cache:
|
||||
text_clip = text_clip_cache[npz_path.name].to(device, dtype)
|
||||
else:
|
||||
print(f" [BigVGAN] No text embedding for {npz_path.name}, skipping", flush=True)
|
||||
continue
|
||||
|
||||
# Load clean audio
|
||||
try:
|
||||
@@ -572,8 +563,6 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
|
||||
finally:
|
||||
tod.to(tod_orig_dev)
|
||||
if clip_model is not None:
|
||||
clip_model.to(clip_orig_dev)
|
||||
del generator
|
||||
soft_empty_cache()
|
||||
|
||||
@@ -797,6 +786,45 @@ class SelvaBigvganTrainer:
|
||||
# _pregenerate_lora_mels handles its own device management for CLIP/tod.
|
||||
mel_converter.to(device)
|
||||
|
||||
# Pre-compute text CLIP embeddings in the main thread.
|
||||
# CLIP weights are inference tensors from ComfyUI loading — they only
|
||||
# work in inference_mode, which is thread-local and active here but NOT
|
||||
# in the worker thread. Pre-computing avoids needing CLIP on GPU in the
|
||||
# worker. Results are cloned+detached so they're normal tensors.
|
||||
text_clip_cache = {}
|
||||
if lora_path is not None:
|
||||
npz_files = sorted(data_dir.glob("*.npz"))
|
||||
if npz_files:
|
||||
prompt_map = {}
|
||||
prompts_file = data_dir / "prompts.txt"
|
||||
if prompts_file.exists():
|
||||
for line in prompts_file.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if "|" in line:
|
||||
fname, prompt = line.split("|", 1)
|
||||
prompt_map[fname.strip()] = prompt.strip()
|
||||
default_prompt = data_dir.name
|
||||
|
||||
# Temporarily move CLIP to GPU for encoding
|
||||
clip_model = feature_utils.clip_model
|
||||
if clip_model is not None:
|
||||
clip_model.to(device)
|
||||
try:
|
||||
for npz_path in npz_files:
|
||||
data = dict(np.load(str(npz_path), allow_pickle=False))
|
||||
prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt))
|
||||
if isinstance(prompt, np.ndarray):
|
||||
prompt = str(prompt)
|
||||
tc = feature_utils.encode_text_clip([prompt])
|
||||
text_clip_cache[npz_path.name] = tc.clone().detach().cpu()
|
||||
finally:
|
||||
if clip_model is not None:
|
||||
clip_model.to("cpu")
|
||||
soft_empty_cache()
|
||||
print(f"[BigVGAN] Pre-encoded {len(text_clip_cache)} text CLIP embeddings", flush=True)
|
||||
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -826,6 +854,7 @@ class SelvaBigvganTrainer:
|
||||
device, dtype, sample_rate,
|
||||
seq_cfg.duration, seed=seed,
|
||||
cache_dir=out_path.parent,
|
||||
text_clip_cache=text_clip_cache,
|
||||
)
|
||||
if not lora_mel_pairs:
|
||||
raise RuntimeError(
|
||||
|
||||
Reference in New Issue
Block a user