def _reference_metrics(gen_wav: torch.Tensor, ref_wav: torch.Tensor, sr: int) -> dict:
    """Score a generated eval clip against the original reference audio.

    Both clips are cut to the shorter of the two lengths, converted to STFT
    magnitudes (Hann window, module-level ``_SPEC_N_FFT`` / ``_SPEC_HOP``),
    and compared on the linear-frequency grid and on an 80-band mel grid.

    Args:
        gen_wav: generated audio; documented contract is a ``[1, L]`` float32
            CPU tensor (mono).
        ref_wav: reference audio, same layout as ``gen_wav``.
        sr: sample rate; forwarded to ``_mel_filterbank`` together with
            ``0`` and ``sr // 2`` (presumably the band edges — confirm
            against that helper).

    Returns:
        dict with three rounded floats:
            log_spectral_distance_db - RMS of the dB magnitude difference
                over every STFT bin, 2 decimals (lower = better)
            mel_cepstral_distortion  - scaled L2 distance in log-mel space,
                2 decimals (lower = better)
            per_band_correlation     - mean Pearson r over the mel bands
                that are not constant, 4 decimals (higher = better);
                0.0 when no band qualifies
    """
    from .selva_audio_preprocessors import _mel_filterbank

    # Compare only the overlapping part of the two clips.
    shared_len = min(gen_wav.shape[-1], ref_wav.shape[-1])
    gen_sig = gen_wav[..., :shared_len].squeeze(0)
    ref_sig = ref_wav[..., :shared_len].squeeze(0)

    fft_size = _SPEC_N_FFT
    hann = torch.hann_window(fft_size)

    def _magnitude(signal: torch.Tensor) -> torch.Tensor:
        # Magnitude spectrogram of one clip.
        spec = torch.stft(signal, fft_size, _SPEC_HOP, window=hann,
                          return_complex=True)
        return spec.abs()

    gen_spec = _magnitude(gen_sig)
    ref_spec = _magnitude(ref_sig)

    # Keep the same number of frames on both sides.
    n_frames = min(gen_spec.shape[1], ref_spec.shape[1])
    gen_spec = gen_spec[:, :n_frames]
    ref_spec = ref_spec[:, :n_frames]

    # --- Log-spectral distance (dB) ---
    # Magnitudes are floored at 1e-8 so silent bins do not produce -inf.
    gen_level = 20.0 * torch.log10(gen_spec.clamp(min=1e-8))
    ref_level = 20.0 * torch.log10(ref_spec.clamp(min=1e-8))
    level_gap = gen_level - ref_level
    lsd = float((level_gap ** 2).mean().sqrt())

    # --- Mel-scale comparison ---
    n_mels = 80
    mel_fb = _mel_filterbank(sr, fft_size, n_mels, 0, sr // 2)  # [n_mels, n_freqs]

    gen_logmel = torch.log(torch.matmul(mel_fb, gen_spec).clamp(min=1e-8))
    ref_logmel = torch.log(torch.matmul(mel_fb, ref_spec).clamp(min=1e-8))

    # Mel cepstral distortion: root of the mean squared log-mel gap along
    # dim 0, averaged over what remains, with the usual 10 / ln(10) scaling.
    mcd_scale = 10.0 / np.log(10)
    logmel_gap = gen_logmel - ref_logmel
    mcd = float(mcd_scale * (logmel_gap ** 2).mean(0).sqrt().mean())

    # --- Per-band Pearson correlation ---
    gen_rows = gen_logmel.numpy()
    ref_rows = ref_logmel.numpy()
    band_r = []
    for band in range(n_mels):
        gen_band = gen_rows[band]
        ref_band = ref_rows[band]
        # Pearson r is undefined for a constant series, so flat bands are skipped.
        is_flat = gen_band.std() < 1e-8 or ref_band.std() < 1e-8
        if not is_flat:
            r_value = np.corrcoef(gen_band, ref_band)[0, 1]
            if np.isfinite(r_value):
                band_r.append(r_value)

    if band_r:
        mean_corr = float(np.mean(band_r))
    else:
        mean_corr = 0.0

    return {
        "log_spectral_distance_db": round(lsd, 2),
        "mel_cepstral_distortion": round(mcd, 2),
        "per_band_correlation": round(mean_corr, 4),
    }