feat: cache pre-generated LoRA mels to disk for reuse
LoRA mel pre-generation runs a full ODE+CFG for every clip, which is slow. Cache results to a .pt file next to the output, keyed by a SHA-256 hash of the LoRA adapter content + generation parameters (seed, steps, CFG, duration, sample rate, npz file list). Automatically reused on subsequent runs when parameters haven't changed. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -27,6 +27,8 @@ BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import json as _json
|
||||
import random
|
||||
import threading
|
||||
from pathlib import Path
|
||||
@@ -373,9 +375,34 @@ def _find_audio_for_npz(npz_path: Path):
|
||||
return None
|
||||
|
||||
|
||||
def _lora_mel_cache_key(lora_adapter_path, data_dir, seed, num_steps,
|
||||
cfg_strength, duration, sample_rate):
|
||||
"""Build a deterministic hash from all parameters that affect LoRA mel generation."""
|
||||
# Hash the LoRA adapter file content (not path — same file moved = same cache)
|
||||
h = hashlib.sha256()
|
||||
with open(lora_adapter_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(1 << 20), b""):
|
||||
h.update(chunk)
|
||||
lora_hash = h.hexdigest()[:16]
|
||||
|
||||
# Hash the sorted .npz file list (names only — content is deterministic per name)
|
||||
npz_names = sorted(p.name for p in Path(data_dir).glob("*.npz"))
|
||||
|
||||
key_data = _json.dumps({
|
||||
"lora_hash": lora_hash,
|
||||
"npz_files": npz_names,
|
||||
"seed": seed,
|
||||
"num_steps": num_steps,
|
||||
"cfg_strength": cfg_strength,
|
||||
"duration": duration,
|
||||
"sample_rate": sample_rate,
|
||||
}, sort_keys=True)
|
||||
return hashlib.sha256(key_data.encode()).hexdigest()[:20]
|
||||
|
||||
|
||||
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
sample_rate, duration, seed=42, num_steps=25,
|
||||
cfg_strength=4.5):
|
||||
cfg_strength=4.5, cache_dir=None):
|
||||
"""Generate LoRA mels for all clips with matching audio in data_dir.
|
||||
|
||||
Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
|
||||
@@ -383,8 +410,26 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
default (4.5) so the degraded mels the vocoder trains on are representative
|
||||
of what it will see at inference time.
|
||||
|
||||
If cache_dir is provided, results are cached to disk and reused when
|
||||
generation parameters haven't changed.
|
||||
|
||||
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
|
||||
"""
|
||||
# ── Check cache ──────────────────────────────────────────────────────────
|
||||
cache_path = None
|
||||
if cache_dir is not None:
|
||||
cache_key = _lora_mel_cache_key(
|
||||
lora_adapter_path, data_dir, seed, num_steps,
|
||||
cfg_strength, duration, sample_rate,
|
||||
)
|
||||
cache_path = Path(cache_dir) / f"lora_mels_{cache_key}.pt"
|
||||
if cache_path.exists():
|
||||
print(f"[BigVGAN] Loading cached LoRA mels: {cache_path.name}", flush=True)
|
||||
cached = torch.load(str(cache_path), map_location="cpu", weights_only=True)
|
||||
pairs = [(m, a) for m, a in zip(cached["mels"], cached["audios"])]
|
||||
print(f"[BigVGAN] Loaded {len(pairs)} cached mel/audio pairs", flush=True)
|
||||
return pairs
|
||||
|
||||
from selva_core.model.lora import apply_lora, load_lora
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
|
||||
@@ -524,6 +569,16 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||
soft_empty_cache()
|
||||
|
||||
print(f"[BigVGAN] Pre-generated {len(pairs)} LoRA mel / clean audio pairs", flush=True)
|
||||
|
||||
# ── Save cache ───────────────────────────────────────────────────────────
|
||||
if cache_path is not None and pairs:
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save({
|
||||
"mels": [m for m, _ in pairs],
|
||||
"audios": [a for _, a in pairs],
|
||||
}, str(cache_path))
|
||||
print(f"[BigVGAN] Cached LoRA mels: {cache_path.name}", flush=True)
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
@@ -756,6 +811,7 @@ class SelvaBigvganTrainer:
|
||||
model, data_dir, str(lora_path),
|
||||
device, dtype, sample_rate,
|
||||
seq_cfg.duration, seed=seed,
|
||||
cache_dir=out_path.parent,
|
||||
)
|
||||
if not lora_mel_pairs:
|
||||
raise RuntimeError(
|
||||
|
||||
Reference in New Issue
Block a user