feat: cache pre-generated LoRA mels to disk for reuse

LoRA mel pre-generation runs a full ODE+CFG for every clip, which is slow.
Cache results to a .pt file next to the output, keyed by a SHA-256 hash
of the LoRA adapter content + generation parameters (seed, steps, CFG,
duration, sample rate, npz file list). Automatically reused on subsequent
runs when parameters haven't changed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 00:30:20 +02:00
parent 0854bd2638
commit 4e6cc4d519
+57 -1
View File
@@ -27,6 +27,8 @@ BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
""" """
import copy import copy
import hashlib
import json as _json
import random import random
import threading import threading
from pathlib import Path from pathlib import Path
@@ -373,9 +375,34 @@ def _find_audio_for_npz(npz_path: Path):
return None return None
def _lora_mel_cache_key(lora_adapter_path, data_dir, seed, num_steps,
cfg_strength, duration, sample_rate):
"""Build a deterministic hash from all parameters that affect LoRA mel generation."""
# Hash the LoRA adapter file content (not path — same file moved = same cache)
h = hashlib.sha256()
with open(lora_adapter_path, "rb") as f:
for chunk in iter(lambda: f.read(1 << 20), b""):
h.update(chunk)
lora_hash = h.hexdigest()[:16]
# Hash the sorted .npz file list (names only — content is deterministic per name)
npz_names = sorted(p.name for p in Path(data_dir).glob("*.npz"))
key_data = _json.dumps({
"lora_hash": lora_hash,
"npz_files": npz_names,
"seed": seed,
"num_steps": num_steps,
"cfg_strength": cfg_strength,
"duration": duration,
"sample_rate": sample_rate,
}, sort_keys=True)
return hashlib.sha256(key_data.encode()).hexdigest()[:20]
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype, def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
sample_rate, duration, seed=42, num_steps=25, sample_rate, duration, seed=42, num_steps=25,
cfg_strength=4.5): cfg_strength=4.5, cache_dir=None):
"""Generate LoRA mels for all clips with matching audio in data_dir. """Generate LoRA mels for all clips with matching audio in data_dir.
Uses the LoRA adapter to run full ODE generation with CFG → VAE decode → Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
@@ -383,8 +410,26 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
default (4.5) so the degraded mels the vocoder trains on are representative default (4.5) so the degraded mels the vocoder trains on are representative
of what it will see at inference time. of what it will see at inference time.
If cache_dir is provided, results are cached to disk and reused when
generation parameters haven't changed.
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors. Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
""" """
# ── Check cache ──────────────────────────────────────────────────────────
cache_path = None
if cache_dir is not None:
cache_key = _lora_mel_cache_key(
lora_adapter_path, data_dir, seed, num_steps,
cfg_strength, duration, sample_rate,
)
cache_path = Path(cache_dir) / f"lora_mels_{cache_key}.pt"
if cache_path.exists():
print(f"[BigVGAN] Loading cached LoRA mels: {cache_path.name}", flush=True)
cached = torch.load(str(cache_path), map_location="cpu", weights_only=True)
pairs = [(m, a) for m, a in zip(cached["mels"], cached["audios"])]
print(f"[BigVGAN] Loaded {len(pairs)} cached mel/audio pairs", flush=True)
return pairs
from selva_core.model.lora import apply_lora, load_lora from selva_core.model.lora import apply_lora, load_lora
from selva_core.model.flow_matching import FlowMatching from selva_core.model.flow_matching import FlowMatching
@@ -524,6 +569,16 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
soft_empty_cache() soft_empty_cache()
print(f"[BigVGAN] Pre-generated {len(pairs)} LoRA mel / clean audio pairs", flush=True) print(f"[BigVGAN] Pre-generated {len(pairs)} LoRA mel / clean audio pairs", flush=True)
# ── Save cache ───────────────────────────────────────────────────────────
if cache_path is not None and pairs:
cache_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
"mels": [m for m, _ in pairs],
"audios": [a for _, a in pairs],
}, str(cache_path))
print(f"[BigVGAN] Cached LoRA mels: {cache_path.name}", flush=True)
return pairs return pairs
@@ -756,6 +811,7 @@ class SelvaBigvganTrainer:
model, data_dir, str(lora_path), model, data_dir, str(lora_path),
device, dtype, sample_rate, device, dtype, sample_rate,
seq_cfg.duration, seed=seed, seq_cfg.duration, seed=seed,
cache_dir=out_path.parent,
) )
if not lora_mel_pairs: if not lora_mel_pairs:
raise RuntimeError( raise RuntimeError(