diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index fa096db..b38164f 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -402,7 +402,8 @@ def _lora_mel_cache_key(lora_adapter_path, data_dir, seed, num_steps, def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, sample_rate, duration, seed=42, num_steps=25, - cfg_strength=4.5, cache_dir=None): + cfg_strength=4.5, cache_dir=None, + text_clip_cache=None): """Generate LoRA mels for all clips with matching audio in data_dir. Uses the LoRA adapter to run full ODE generation with CFG → VAE decode → @@ -413,6 +414,10 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, If cache_dir is provided, results are cached to disk and reused when generation parameters haven't changed. + text_clip_cache: dict mapping npz filename → pre-computed text CLIP + embedding tensor [1, seq, dim]. Pre-computed in the main thread where + inference_mode is active (CLIP weights are inference tensors). + Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors. """ # ── Check cache ────────────────────────────────────────────────────────── @@ -471,32 +476,18 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, raise ValueError(f"[BigVGAN] No .npz files in {data_dir} — " "point data_dir to your LoRA training features directory") - # Load prompt map if available (same logic as LoRA trainer) - prompt_map = {} - prompts_file = data_dir / "prompts.txt" - if prompts_file.exists(): - for line in prompts_file.read_text(encoding="utf-8").splitlines(): - line = line.strip() - if not line or line.startswith("#"): - continue - if "|" in line: - fname, prompt = line.split("|", 1) - prompt_map[fname.strip()] = prompt.strip() - default_prompt = data_dir.name + if text_clip_cache is None: + text_clip_cache = {} fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps) rng = torch.Generator(device=device).manual_seed(seed) - # Move only the components we need to GPU for generation: - # - tod (VAE+vocoder) for decode - # - clip_model for encode_text_clip + # Move only tod (VAE+vocoder) to GPU for decode. + # CLIP is NOT needed here — text embeddings are pre-computed in the main + # thread and passed via text_clip_cache. tod = feature_utils.tod tod_orig_dev = next(tod.parameters()).device tod.to(device) - clip_model = feature_utils.clip_model - if clip_model is not None: - clip_orig_dev = next(clip_model.parameters()).device - clip_model.to(device) pairs = [] try: @@ -525,12 +516,12 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, elif sync_f.shape[1] > s_tgt: sync_f = sync_f[:, :s_tgt, :] - # Text CLIP encoding - prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt)) - if isinstance(prompt, np.ndarray): - prompt = str(prompt) - with torch.inference_mode(): - text_clip = feature_utils.encode_text_clip([prompt]).to(device, dtype) + # Text CLIP embedding (pre-computed in main thread) + if npz_path.name in text_clip_cache: + text_clip = text_clip_cache[npz_path.name].to(device, dtype) + else: + print(f" [BigVGAN] No text embedding for {npz_path.name}, skipping", flush=True) + continue # Load clean audio try: @@ -572,8 +563,6 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, finally: tod.to(tod_orig_dev) - if clip_model is not None: - clip_model.to(clip_orig_dev) del generator soft_empty_cache() @@ -797,6 +786,45 @@ class SelvaBigvganTrainer: # _pregenerate_lora_mels handles its own device management for CLIP/tod. mel_converter.to(device) + # Pre-compute text CLIP embeddings in the main thread. + # CLIP weights are inference tensors from ComfyUI loading — they only + # work in inference_mode, which is thread-local and active here but NOT + # in the worker thread. Pre-computing avoids needing CLIP on GPU in the + # worker. Results are cloned+detached so they're normal tensors. + text_clip_cache = {} + if lora_path is not None: + npz_files = sorted(data_dir.glob("*.npz")) + if npz_files: + prompt_map = {} + prompts_file = data_dir / "prompts.txt" + if prompts_file.exists(): + for line in prompts_file.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "|" in line: + fname, prompt = line.split("|", 1) + prompt_map[fname.strip()] = prompt.strip() + default_prompt = data_dir.name + + # Temporarily move CLIP to GPU for encoding + clip_model = feature_utils.clip_model + if clip_model is not None: + clip_model.to(device) + try: + for npz_path in npz_files: + data = dict(np.load(str(npz_path), allow_pickle=False)) + prompt = prompt_map.get(npz_path.name, data.get("prompt", default_prompt)) + if isinstance(prompt, np.ndarray): + prompt = str(prompt) + tc = feature_utils.encode_text_clip([prompt]) + text_clip_cache[npz_path.name] = tc.clone().detach().cpu() + finally: + if clip_model is not None: + clip_model.to("cpu") + soft_empty_cache() + print(f"[BigVGAN] Pre-encoded {len(text_clip_cache)} text CLIP embeddings", flush=True) + pbar = comfy.utils.ProgressBar(steps) # ----------------------------------------------------------------------- @@ -826,6 +854,7 @@ class SelvaBigvganTrainer: device, dtype, sample_rate, seq_cfg.duration, seed=seed, cache_dir=out_path.parent, + text_clip_cache=text_clip_cache, ) if not lora_mel_pairs: raise RuntimeError(