diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 4130a17..8965aad 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -27,6 +27,8 @@ BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader. """ import copy +import hashlib +import json as _json import random import threading from pathlib import Path @@ -373,9 +375,34 @@ def _find_audio_for_npz(npz_path: Path): return None +def _lora_mel_cache_key(lora_adapter_path, data_dir, seed, num_steps, + cfg_strength, duration, sample_rate): + """Build a deterministic hash from all parameters that affect LoRA mel generation.""" + # Hash the LoRA adapter file content (not path — same file moved = same cache) + h = hashlib.sha256() + with open(lora_adapter_path, "rb") as f: + for chunk in iter(lambda: f.read(1 << 20), b""): + h.update(chunk) + lora_hash = h.hexdigest()[:16] + + # Hash the sorted .npz file list (names only — content is deterministic per name) + npz_names = sorted(p.name for p in Path(data_dir).glob("*.npz")) + + key_data = _json.dumps({ + "lora_hash": lora_hash, + "npz_files": npz_names, + "seed": seed, + "num_steps": num_steps, + "cfg_strength": cfg_strength, + "duration": duration, + "sample_rate": sample_rate, + }, sort_keys=True) + return hashlib.sha256(key_data.encode()).hexdigest()[:20] + + def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, sample_rate, duration, seed=42, num_steps=25, - cfg_strength=4.5): + cfg_strength=4.5, cache_dir=None): """Generate LoRA mels for all clips with matching audio in data_dir. Uses the LoRA adapter to run full ODE generation with CFG → VAE decode → @@ -383,8 +410,26 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, default (4.5) so the degraded mels the vocoder trains on are representative of what it will see at inference time. + If cache_dir is provided, results are cached to disk and reused when + generation parameters haven't changed. + Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors. """ + # ── Check cache ────────────────────────────────────────────────────────── + cache_path = None + if cache_dir is not None: + cache_key = _lora_mel_cache_key( + lora_adapter_path, data_dir, seed, num_steps, + cfg_strength, duration, sample_rate, + ) + cache_path = Path(cache_dir) / f"lora_mels_{cache_key}.pt" + if cache_path.exists(): + print(f"[BigVGAN] Loading cached LoRA mels: {cache_path.name}", flush=True) + cached = torch.load(str(cache_path), map_location="cpu", weights_only=True) + pairs = [(m, a) for m, a in zip(cached["mels"], cached["audios"])] + print(f"[BigVGAN] Loaded {len(pairs)} cached mel/audio pairs", flush=True) + return pairs + from selva_core.model.lora import apply_lora, load_lora from selva_core.model.flow_matching import FlowMatching @@ -524,6 +569,16 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, soft_empty_cache() print(f"[BigVGAN] Pre-generated {len(pairs)} LoRA mel / clean audio pairs", flush=True) + + # ── Save cache ─────────────────────────────────────────────────────────── + if cache_path is not None and pairs: + cache_path.parent.mkdir(parents=True, exist_ok=True) + torch.save({ + "mels": [m for m, _ in pairs], + "audios": [a for _, a in pairs], + }, str(cache_path)) + print(f"[BigVGAN] Cached LoRA mels: {cache_path.name}", flush=True) + return pairs @@ -756,6 +811,7 @@ class SelvaBigvganTrainer: model, data_dir, str(lora_path), device, dtype, sample_rate, seq_cfg.duration, seed=seed, + cache_dir=out_path.parent, ) if not lora_mel_pairs: raise RuntimeError(