feat: add reference audio comparison metrics to LoRA trainer eval

New _reference_metrics() computes LSD, MCD, and per-band correlation
between eval samples and the original source audio at each checkpoint.
Loads reference audio once before the training loop and logs metrics
alongside existing spectral metrics.

Also fix batch_size in lora_optimized_dataset.json (4 -> 16).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 15:04:07 +02:00
parent f745e241c4
commit 65dc549494
2 changed files with 92 additions and 1 deletions
+1 -1
View File
@@ -7,7 +7,7 @@
"rank": 128, "rank": 128,
"lr": 3e-4, "lr": 3e-4,
"steps": 5000, "steps": 5000,
"batch_size": 4, "batch_size": 16,
"warmup_steps": 100, "warmup_steps": 100,
"save_every": 1000, "save_every": 1000,
"seed": 42, "seed": 42,
+91
View File
@@ -224,6 +224,77 @@ def _spectral_metrics(wav: torch.Tensor, sr: int) -> dict:
} }
def _reference_metrics(gen_wav: torch.Tensor, ref_wav: torch.Tensor, sr: int) -> dict:
    """Score a generated eval sample against the original reference recording.

    Both waveforms are cropped to the shorter of the two, converted to STFT
    magnitudes (Hann window, ``_SPEC_N_FFT`` / ``_SPEC_HOP``), and compared
    on the linear-frequency axis and on an 80-band mel axis.

    Args:
        gen_wav: generated audio, ``[1, L]`` float32 CPU tensor (mono).
        ref_wav: reference audio, same layout as ``gen_wav``.
        sr: sampling rate in Hz; the mel filterbank spans 0 .. ``sr // 2``.

    Returns:
        dict with three rounded floats:
            log_spectral_distance_db — RMS of the dB difference over all
                STFT bins and frames (lower = better)
            mel_cepstral_distortion  — per-frame RMS log-mel difference,
                scaled by 10 / ln(10), averaged over frames (lower = better)
            per_band_correlation     — mean Pearson r of the log-mel
                envelopes over the usable mel bands (higher = better)
    """
    from .selva_audio_preprocessors import _mel_filterbank

    floor = 1e-8  # keeps log() finite and doubles as the "flat band" threshold
    num_bands = 80
    fft_size = _SPEC_N_FFT
    hop_size = _SPEC_HOP
    hann = torch.hann_window(fft_size)

    def _magnitude(signal: torch.Tensor) -> torch.Tensor:
        # Complex STFT -> magnitude spectrogram.
        spec = torch.stft(signal, fft_size, hop_size, window=hann, return_complex=True)
        return spec.abs()

    def _to_db(mag: torch.Tensor) -> torch.Tensor:
        return 20.0 * torch.log10(mag.clamp(min=floor))

    def _log_mel(bank, mag: torch.Tensor) -> torch.Tensor:
        return torch.log(torch.matmul(bank, mag).clamp(min=floor))

    # Compare only the overlapping part of the two clips.
    n_samples = min(gen_wav.shape[-1], ref_wav.shape[-1])
    gen_mag = _magnitude(gen_wav[..., :n_samples].squeeze(0))
    ref_mag = _magnitude(ref_wav[..., :n_samples].squeeze(0))

    # Both clips were cropped to the same sample count above; this keeps the
    # frame axes aligned regardless.
    n_frames = min(gen_mag.shape[1], ref_mag.shape[1])
    gen_mag, ref_mag = gen_mag[:, :n_frames], ref_mag[:, :n_frames]

    # --- Log-spectral distance (dB) -------------------------------------
    db_delta = _to_db(gen_mag) - _to_db(ref_mag)
    lsd = float((db_delta ** 2).mean().sqrt())

    # --- Mel-domain comparison ------------------------------------------
    mel_bank = _mel_filterbank(sr, fft_size, num_bands, 0, sr // 2)
    gen_log_mel = _log_mel(mel_bank, gen_mag)
    ref_log_mel = _log_mel(mel_bank, ref_mag)

    # RMS over the band axis for each frame, then averaged over frames.
    per_frame_rms = ((gen_log_mel - ref_log_mel) ** 2).mean(0).sqrt()
    mcd = float((10.0 / np.log(10)) * per_frame_rms.mean())

    # --- Per-band Pearson correlation -----------------------------------
    gen_bands = gen_log_mel.numpy()
    ref_bands = ref_log_mel.numpy()
    band_scores = []
    for band in range(num_bands):
        gen_env = gen_bands[band]
        ref_env = ref_bands[band]
        # A (near-)constant envelope has no defined correlation; skip it.
        is_flat = gen_env.std() < floor or ref_env.std() < floor
        if is_flat:
            continue
        r_value = np.corrcoef(gen_env, ref_env)[0, 1]
        if np.isfinite(r_value):
            band_scores.append(r_value)

    if band_scores:
        mean_corr = float(np.mean(band_scores))
    else:
        # No band was usable: report a neutral correlation.
        mean_corr = 0.0

    return {
        "log_spectral_distance_db": round(lsd, 2),
        "mel_cepstral_distortion": round(mcd, 2),
        "per_band_correlation": round(mean_corr, 4),
    }
def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None: def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None:
"""Save a log-frequency dB spectrogram PNG for an eval sample. """Save a log-frequency dB spectrogram PNG for an eval sample.
@@ -778,6 +849,19 @@ class SelvaLoraTrainer:
f"(step {start_step + 1}{steps}, batch_size={batch_size}, " f"(step {start_step + 1}{steps}, batch_size={batch_size}, "
f"timestep_mode={timestep_mode})\n", flush=True) f"timestep_mode={timestep_mode})\n", flush=True)
# Load reference audio for eval comparison against original
ref_wav = None
try:
npz_files = sorted(data_dir.glob("*.npz"))
if npz_files:
ref_audio_path = _find_audio(npz_files[0])
if ref_audio_path is not None:
ref_wav = _load_audio(ref_audio_path, seq_cfg.sampling_rate,
seq_cfg.duration).unsqueeze(0) # [1, L]
print(f"[LoRA Trainer] Reference audio: {ref_audio_path.name}", flush=True)
except Exception as e:
print(f"[LoRA Trainer] Could not load reference audio: {e}", flush=True)
last_step = start_step last_step = start_step
completed = False completed = False
try: try:
@@ -911,6 +995,13 @@ class SelvaLoraTrainer:
f"flatness={metrics['spectral_flatness']:.3f} " f"flatness={metrics['spectral_flatness']:.3f} "
f"temporal_var={metrics['temporal_variance']:.3f}", flush=True) f"temporal_var={metrics['temporal_variance']:.3f}", flush=True)
_save_spectrogram(wav, sr, wav_path) _save_spectrogram(wav, sr, wav_path)
if ref_wav is not None:
ref_m = _reference_metrics(wav, ref_wav, sr)
spectral_metrics[step].update(ref_m)
print(f"[LoRA Trainer] vs Reference: "
f"LSD={ref_m['log_spectral_distance_db']:.1f}dB "
f"MCD={ref_m['mel_cepstral_distortion']:.2f} "
f"corr={ref_m['per_band_correlation']:.3f}", flush=True)
except Exception as e: except Exception as e:
print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True) print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True)