feat: add reference audio comparison metrics to LoRA trainer eval
New _reference_metrics() computes LSD, MCD, and per-band correlation between eval samples and the original source audio at each checkpoint. Loads reference audio once before the training loop and logs metrics alongside existing spectral metrics. Also fix batch_size in lora_optimized_dataset.json (4 -> 16). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
"rank": 128,
|
||||
"lr": 3e-4,
|
||||
"steps": 5000,
|
||||
"batch_size": 4,
|
||||
"batch_size": 16,
|
||||
"warmup_steps": 100,
|
||||
"save_every": 1000,
|
||||
"seed": 42,
|
||||
|
||||
@@ -224,6 +224,77 @@ def _spectral_metrics(wav: torch.Tensor, sr: int) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _reference_metrics(gen_wav: torch.Tensor, ref_wav: torch.Tensor, sr: int) -> dict:
|
||||
"""Compare generated eval sample against the original reference audio.
|
||||
|
||||
gen_wav, ref_wav: [1, L] float32 CPU tensors (mono).
|
||||
|
||||
Returns:
|
||||
log_spectral_distance_db — RMS dB difference across freq bins (lower = better)
|
||||
mel_cepstral_distortion — scaled L2 in log-mel space (lower = better)
|
||||
per_band_correlation — mean Pearson r across mel bands (higher = better)
|
||||
"""
|
||||
from .selva_audio_preprocessors import _mel_filterbank
|
||||
|
||||
L = min(gen_wav.shape[-1], ref_wav.shape[-1])
|
||||
gen = gen_wav[..., :L].squeeze(0)
|
||||
ref = ref_wav[..., :L].squeeze(0)
|
||||
|
||||
n_fft = _SPEC_N_FFT
|
||||
hop = _SPEC_HOP
|
||||
window = torch.hann_window(n_fft)
|
||||
|
||||
gen_stft = torch.stft(gen, n_fft, hop, window=window, return_complex=True)
|
||||
ref_stft = torch.stft(ref, n_fft, hop, window=window, return_complex=True)
|
||||
|
||||
gen_mag = gen_stft.abs()
|
||||
ref_mag = ref_stft.abs()
|
||||
|
||||
T = min(gen_mag.shape[1], ref_mag.shape[1])
|
||||
gen_mag = gen_mag[:, :T]
|
||||
ref_mag = ref_mag[:, :T]
|
||||
|
||||
# Log-spectral distance (dB)
|
||||
gen_db = 20.0 * torch.log10(gen_mag.clamp(min=1e-8))
|
||||
ref_db = 20.0 * torch.log10(ref_mag.clamp(min=1e-8))
|
||||
lsd = float(((gen_db - ref_db) ** 2).mean().sqrt())
|
||||
|
||||
# Mel-scale comparison
|
||||
n_mels = 80
|
||||
fb = _mel_filterbank(sr, n_fft, n_mels, 0, sr // 2) # [n_mels, n_freqs]
|
||||
|
||||
gen_mel = torch.matmul(fb, gen_mag).clamp(min=1e-8)
|
||||
ref_mel = torch.matmul(fb, ref_mag).clamp(min=1e-8)
|
||||
|
||||
gen_mel_log = torch.log(gen_mel)
|
||||
ref_mel_log = torch.log(ref_mel)
|
||||
|
||||
# Mel cepstral distortion (L2 in log-mel space, standard scaling)
|
||||
mcd = float(
|
||||
(10.0 / np.log(10))
|
||||
* ((gen_mel_log - ref_mel_log) ** 2).mean(0).sqrt().mean()
|
||||
)
|
||||
|
||||
# Per-band Pearson correlation
|
||||
gen_np = gen_mel_log.numpy()
|
||||
ref_np = ref_mel_log.numpy()
|
||||
correlations = []
|
||||
for b in range(n_mels):
|
||||
g, r = gen_np[b], ref_np[b]
|
||||
if g.std() < 1e-8 or r.std() < 1e-8:
|
||||
continue
|
||||
corr = np.corrcoef(g, r)[0, 1]
|
||||
if np.isfinite(corr):
|
||||
correlations.append(corr)
|
||||
mean_corr = float(np.mean(correlations)) if correlations else 0.0
|
||||
|
||||
return {
|
||||
"log_spectral_distance_db": round(lsd, 2),
|
||||
"mel_cepstral_distortion": round(mcd, 2),
|
||||
"per_band_correlation": round(mean_corr, 4),
|
||||
}
|
||||
|
||||
|
||||
def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None:
|
||||
"""Save a log-frequency dB spectrogram PNG for an eval sample.
|
||||
|
||||
@@ -778,6 +849,19 @@ class SelvaLoraTrainer:
|
||||
f"(step {start_step + 1} → {steps}, batch_size={batch_size}, "
|
||||
f"timestep_mode={timestep_mode})\n", flush=True)
|
||||
|
||||
# Load reference audio for eval comparison against original
|
||||
ref_wav = None
|
||||
try:
|
||||
npz_files = sorted(data_dir.glob("*.npz"))
|
||||
if npz_files:
|
||||
ref_audio_path = _find_audio(npz_files[0])
|
||||
if ref_audio_path is not None:
|
||||
ref_wav = _load_audio(ref_audio_path, seq_cfg.sampling_rate,
|
||||
seq_cfg.duration).unsqueeze(0) # [1, L]
|
||||
print(f"[LoRA Trainer] Reference audio: {ref_audio_path.name}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[LoRA Trainer] Could not load reference audio: {e}", flush=True)
|
||||
|
||||
last_step = start_step
|
||||
completed = False
|
||||
try:
|
||||
@@ -911,6 +995,13 @@ class SelvaLoraTrainer:
|
||||
f"flatness={metrics['spectral_flatness']:.3f} "
|
||||
f"temporal_var={metrics['temporal_variance']:.3f}", flush=True)
|
||||
_save_spectrogram(wav, sr, wav_path)
|
||||
if ref_wav is not None:
|
||||
ref_m = _reference_metrics(wav, ref_wav, sr)
|
||||
spectral_metrics[step].update(ref_m)
|
||||
print(f"[LoRA Trainer] vs Reference: "
|
||||
f"LSD={ref_m['log_spectral_distance_db']:.1f}dB "
|
||||
f"MCD={ref_m['mel_cepstral_distortion']:.2f} "
|
||||
f"corr={ref_m['per_band_correlation']:.3f}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user