feat: cache pre-generated LoRA mels to disk for reuse
LoRA mel pre-generation runs a full ODE+CFG for every clip, which is slow. Cache results to a .pt file next to the output, keyed by a SHA-256 hash of the LoRA adapter content + generation parameters (seed, steps, CFG, duration, sample rate, npz file list). Automatically reused on subsequent runs when parameters haven't changed. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -27,6 +27,8 @@ BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import hashlib
|
||||||
|
import json as _json
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -373,9 +375,34 @@ def _find_audio_for_npz(npz_path: Path):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _lora_mel_cache_key(lora_adapter_path, data_dir, seed, num_steps,
|
||||||
|
cfg_strength, duration, sample_rate):
|
||||||
|
"""Build a deterministic hash from all parameters that affect LoRA mel generation."""
|
||||||
|
# Hash the LoRA adapter file content (not path — same file moved = same cache)
|
||||||
|
h = hashlib.sha256()
|
||||||
|
with open(lora_adapter_path, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(1 << 20), b""):
|
||||||
|
h.update(chunk)
|
||||||
|
lora_hash = h.hexdigest()[:16]
|
||||||
|
|
||||||
|
# Hash the sorted .npz file list (names only — content is deterministic per name)
|
||||||
|
npz_names = sorted(p.name for p in Path(data_dir).glob("*.npz"))
|
||||||
|
|
||||||
|
key_data = _json.dumps({
|
||||||
|
"lora_hash": lora_hash,
|
||||||
|
"npz_files": npz_names,
|
||||||
|
"seed": seed,
|
||||||
|
"num_steps": num_steps,
|
||||||
|
"cfg_strength": cfg_strength,
|
||||||
|
"duration": duration,
|
||||||
|
"sample_rate": sample_rate,
|
||||||
|
}, sort_keys=True)
|
||||||
|
return hashlib.sha256(key_data.encode()).hexdigest()[:20]
|
||||||
|
|
||||||
|
|
||||||
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
||||||
sample_rate, duration, seed=42, num_steps=25,
|
sample_rate, duration, seed=42, num_steps=25,
|
||||||
cfg_strength=4.5):
|
cfg_strength=4.5, cache_dir=None):
|
||||||
"""Generate LoRA mels for all clips with matching audio in data_dir.
|
"""Generate LoRA mels for all clips with matching audio in data_dir.
|
||||||
|
|
||||||
Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
|
Uses the LoRA adapter to run full ODE generation with CFG → VAE decode →
|
||||||
@@ -383,8 +410,26 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
|||||||
default (4.5) so the degraded mels the vocoder trains on are representative
|
default (4.5) so the degraded mels the vocoder trains on are representative
|
||||||
of what it will see at inference time.
|
of what it will see at inference time.
|
||||||
|
|
||||||
|
If cache_dir is provided, results are cached to disk and reused when
|
||||||
|
generation parameters haven't changed.
|
||||||
|
|
||||||
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
|
Returns list of (mel [n_mels, T_mel], audio [L]) CPU tensors.
|
||||||
"""
|
"""
|
||||||
|
# ── Check cache ──────────────────────────────────────────────────────────
|
||||||
|
cache_path = None
|
||||||
|
if cache_dir is not None:
|
||||||
|
cache_key = _lora_mel_cache_key(
|
||||||
|
lora_adapter_path, data_dir, seed, num_steps,
|
||||||
|
cfg_strength, duration, sample_rate,
|
||||||
|
)
|
||||||
|
cache_path = Path(cache_dir) / f"lora_mels_{cache_key}.pt"
|
||||||
|
if cache_path.exists():
|
||||||
|
print(f"[BigVGAN] Loading cached LoRA mels: {cache_path.name}", flush=True)
|
||||||
|
cached = torch.load(str(cache_path), map_location="cpu", weights_only=True)
|
||||||
|
pairs = [(m, a) for m, a in zip(cached["mels"], cached["audios"])]
|
||||||
|
print(f"[BigVGAN] Loaded {len(pairs)} cached mel/audio pairs", flush=True)
|
||||||
|
return pairs
|
||||||
|
|
||||||
from selva_core.model.lora import apply_lora, load_lora
|
from selva_core.model.lora import apply_lora, load_lora
|
||||||
from selva_core.model.flow_matching import FlowMatching
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
|
|
||||||
@@ -524,6 +569,16 @@ def _pregenerate_lora_mels(model, data_dir, lora_adapter_path, device, dtype,
|
|||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|
||||||
print(f"[BigVGAN] Pre-generated {len(pairs)} LoRA mel / clean audio pairs", flush=True)
|
print(f"[BigVGAN] Pre-generated {len(pairs)} LoRA mel / clean audio pairs", flush=True)
|
||||||
|
|
||||||
|
# ── Save cache ───────────────────────────────────────────────────────────
|
||||||
|
if cache_path is not None and pairs:
|
||||||
|
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.save({
|
||||||
|
"mels": [m for m, _ in pairs],
|
||||||
|
"audios": [a for _, a in pairs],
|
||||||
|
}, str(cache_path))
|
||||||
|
print(f"[BigVGAN] Cached LoRA mels: {cache_path.name}", flush=True)
|
||||||
|
|
||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
@@ -756,6 +811,7 @@ class SelvaBigvganTrainer:
|
|||||||
model, data_dir, str(lora_path),
|
model, data_dir, str(lora_path),
|
||||||
device, dtype, sample_rate,
|
device, dtype, sample_rate,
|
||||||
seq_cfg.duration, seed=seed,
|
seq_cfg.duration, seed=seed,
|
||||||
|
cache_dir=out_path.parent,
|
||||||
)
|
)
|
||||||
if not lora_mel_pairs:
|
if not lora_mel_pairs:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
Reference in New Issue
Block a user