fix: pre-compute text CLIP embeddings in main thread to avoid inference tensor crash
CLIP weights are inference tensors from ComfyUI loading. inference_mode is thread-local, so the worker thread can't use CLIP even with a context manager.

Pre-compute all text embeddings in the main thread (where inference_mode IS active), clone+detach them to normal tensors, and pass them to the worker via the text_clip_cache dict. CLIP no longer needs to be on GPU during pre-generation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -402,7 +402,8 @@ def _lora_mel_cache_key(lora_adapter_path, data_dir, seed, num_steps,
|
|||||||
|
|
||||||
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||||
sample_rate, duration, seed=42, num_steps=25,
|
sample_rate, duration, seed=42, num_steps=25,
|
||||||
cfg_strength=4.5, cache_dir=None):
|
cfg_strength=4.5, cache_dir=None,
|
||||||
|
text_clip_cache=None):
|
||||||
"""Generate LoRA mels for all clips with matching audio in data_dir.
|
"""Generate LoRA mels for all clips with matching audio in data_dir.
|
||||||
|
|
||||||
Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
|
Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
|
||||||
@@ -413,6 +414,10 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
|||||||
If cache_dir is provided, results are cached to disk and reused when
|
If cache_dir is provided, results are cached to disk and reused when
|
||||||
generation parameters haven't changed.
|
generation parameters haven't changed.
|
||||||
|
|
||||||
|
text_clip_cache: dict mapping npz filename → pre-computed text CLIP
|
||||||
|
embedding tensor [1, seq, dim]. Pre-computed in the main thread where
|
||||||
|
inference_mode is active (CLIP weights are inference tensors).
|
||||||
|
|
||||||
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
|
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
|
||||||
"""
|
"""
|
||||||
# ── Check cache ──────────────────────────────────────────────────────────
|
# ── Check cache ──────────────────────────────────────────────────────────
|
||||||
@@ -471,32 +476,18 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
|||||||
raise ValueError(f"[BigVGAN] No .npz files in {data_dir} — "
|
raise ValueError(f"[BigVGAN] No .npz files in {data_dir} — "
|
||||||
"point data_dir to your LoRA training features directory")
|
"point data_dir to your LoRA training features directory")
|
||||||
|
|
||||||
# Load prompt map if available (same logic as LoRA trainer)
|
if text_clip_cache is None:
|
||||||
prompt_map = {}
|
text_clip_cache = {}
|
||||||
prompts_file = data_dir / "prompts.txt"
|
|
||||||
if prompts_file.exists():
|
|
||||||
for line in prompts_file.read_text(encoding="utf-8").splitlines():
|
|
||||||
line = line.strip()
|
|
||||||
if not line or line.startswith("#"):
|
|
||||||
continue
|
|
||||||
if "|" in line:
|
|
||||||
fname, prompt = line.split("|", 1)
|
|
||||||
prompt_map[fname.strip()] = prompt.strip()
|
|
||||||
default_prompt = data_dir.name
|
|
||||||
|
|
||||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
||||||
rng = torch.Generator(device=device).manual_seed(seed)
|
rng = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
|
||||||
# Move only the components we need to GPU for generation:
|
# Move only tod (VAE+vocoder) to GPU for decode.
|
||||||
# - tod (VAE+vocoder) for decode
|
# CLIP is NOT needed here — text embeddings are pre-computed in the main
|
||||||
# - clip_model for encode_text_clip
|
# thread and passed via text_clip_cache.
|
||||||
tod = feature_utils.tod
|
tod = feature_utils.tod
|
||||||
tod_orig_dev = next(tod.parameters()).device
|
tod_orig_dev = next(tod.parameters()).device
|
||||||
tod.to(device)
|
tod.to(device)
|
||||||
clip_model = feature_utils.clip_model
|
|
||||||
if clip_model is not None:
|
|
||||||
clip_orig_dev = next(clip_model.parameters()).device
|
|
||||||
clip_model.to(device)
|
|
||||||
|
|
||||||
pairs = []
|
pairs = []
|
||||||
try:
|
try:
|
||||||
@@ -525,12 +516,12 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
|||||||
elif sync_f.shape[1] > s_tgt:
|
elif sync_f.shape[1] > s_tgt:
|
||||||
sync_f = sync_f[:, :s_tgt, :]
|
sync_f = sync_f[:, :s_tgt, :]
|
||||||
|
|
||||||
# Text CLIP encoding
|
# Text CLIP embedding (pre-computed in main thread)
|
||||||
prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt))
|
if npz_path.name in text_clip_cache:
|
||||||
if isinstance(prompt, np.ndarray):
|
text_clip = text_clip_cache[npz_path.name].to(device, dtype)
|
||||||
prompt = str(prompt)
|
else:
|
||||||
with torch.inference_mode():
|
print(f" [BigVGAN] No text embedding for {npz_path.name}, skipping", flush=True)
|
||||||
text_clip = feature_utils.encode_text_clip([prompt]).to(device, dtype)
|
continue
|
||||||
|
|
||||||
# Load clean audio
|
# Load clean audio
|
||||||
try:
|
try:
|
||||||
@@ -572,8 +563,6 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
|||||||
|
|
||||||
finally:
|
finally:
|
||||||
tod.to(tod_orig_dev)
|
tod.to(tod_orig_dev)
|
||||||
if clip_model is not None:
|
|
||||||
clip_model.to(clip_orig_dev)
|
|
||||||
del generator
|
del generator
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|
||||||
@@ -797,6 +786,45 @@ class SelvaBigvganTrainer:
|
|||||||
# _pregenerate_lora_mels handles its own device management for CLIP/tod.
|
# _pregenerate_lora_mels handles its own device management for CLIP/tod.
|
||||||
mel_converter.to(device)
|
mel_converter.to(device)
|
||||||
|
|
||||||
|
# Pre-compute text CLIP embeddings in the main thread.
|
||||||
|
# CLIP weights are inference tensors from ComfyUI loading — they only
|
||||||
|
# work in inference_mode, which is thread-local and active here but NOT
|
||||||
|
# in the worker thread. Pre-computing avoids needing CLIP on GPU in the
|
||||||
|
# worker. Results are cloned+detached so they're normal tensors.
|
||||||
|
text_clip_cache = {}
|
||||||
|
if lora_path is not None:
|
||||||
|
npz_files = sorted(data_dir.glob("*.npz"))
|
||||||
|
if npz_files:
|
||||||
|
prompt_map = {}
|
||||||
|
prompts_file = data_dir / "prompts.txt"
|
||||||
|
if prompts_file.exists():
|
||||||
|
for line in prompts_file.read_text(encoding="utf-8").splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
if "|" in line:
|
||||||
|
fname, prompt = line.split("|", 1)
|
||||||
|
prompt_map[fname.strip()] = prompt.strip()
|
||||||
|
default_prompt = data_dir.name
|
||||||
|
|
||||||
|
# Temporarily move CLIP to GPU for encoding
|
||||||
|
clip_model = feature_utils.clip_model
|
||||||
|
if clip_model is not None:
|
||||||
|
clip_model.to(device)
|
||||||
|
try:
|
||||||
|
for npz_path in npz_files:
|
||||||
|
data = dict(np.load(str(npz_path), allow_pickle=False))
|
||||||
|
prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt))
|
||||||
|
if isinstance(prompt, np.ndarray):
|
||||||
|
prompt = str(prompt)
|
||||||
|
tc = feature_utils.encode_text_clip([prompt])
|
||||||
|
text_clip_cache[npz_path.name] = tc.clone().detach().cpu()
|
||||||
|
finally:
|
||||||
|
if clip_model is not None:
|
||||||
|
clip_model.to("cpu")
|
||||||
|
soft_empty_cache()
|
||||||
|
print(f"[BigVGAN] Pre-encoded {len(text_clip_cache)} text CLIP embeddings", flush=True)
|
||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
# -----------------------------------------------------------------------
|
||||||
@@ -826,6 +854,7 @@ class SelvaBigvganTrainer:
|
|||||||
device, dtype, sample_rate,
|
device, dtype, sample_rate,
|
||||||
seq_cfg.duration, seed=seed,
|
seq_cfg.duration, seed=seed,
|
||||||
cache_dir=out_path.parent,
|
cache_dir=out_path.parent,
|
||||||
|
text_clip_cache=text_clip_cache,
|
||||||
)
|
)
|
||||||
if not lora_mel_pairs:
|
if not lora_mel_pairs:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
Reference in New Issue
Block a user