diff --git a/experiments/lora_optimized_dataset.json b/experiments/lora_optimized_dataset.json index 3194961..fd212d5 100644 --- a/experiments/lora_optimized_dataset.json +++ b/experiments/lora_optimized_dataset.json @@ -7,7 +7,7 @@ "rank": 128, "lr": 3e-4, "steps": 5000, - "batch_size": 4, + "batch_size": 16, "warmup_steps": 100, "save_every": 1000, "seed": 42, diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index e60a157..28486f9 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -224,6 +224,77 @@ def _spectral_metrics(wav: torch.Tensor, sr: int) -> dict: } +def _reference_metrics(gen_wav: torch.Tensor, ref_wav: torch.Tensor, sr: int) -> dict: + """Compare generated eval sample against the original reference audio. + + gen_wav, ref_wav: [1, L] float32 CPU tensors (mono). + + Returns: + log_spectral_distance_db — RMS dB difference across freq bins (lower = better) + mel_cepstral_distortion — scaled L2 in log-mel space (lower = better) + per_band_correlation — mean Pearson r across mel bands (higher = better) + """ + from .selva_audio_preprocessors import _mel_filterbank + + L = min(gen_wav.shape[-1], ref_wav.shape[-1]) + gen = gen_wav[..., :L].squeeze(0) + ref = ref_wav[..., :L].squeeze(0) + + n_fft = _SPEC_N_FFT + hop = _SPEC_HOP + window = torch.hann_window(n_fft) + + gen_stft = torch.stft(gen, n_fft, hop, window=window, return_complex=True) + ref_stft = torch.stft(ref, n_fft, hop, window=window, return_complex=True) + + gen_mag = gen_stft.abs() + ref_mag = ref_stft.abs() + + T = min(gen_mag.shape[1], ref_mag.shape[1]) + gen_mag = gen_mag[:, :T] + ref_mag = ref_mag[:, :T] + + # Log-spectral distance (dB) + gen_db = 20.0 * torch.log10(gen_mag.clamp(min=1e-8)) + ref_db = 20.0 * torch.log10(ref_mag.clamp(min=1e-8)) + lsd = float(((gen_db - ref_db) ** 2).mean().sqrt()) + + # Mel-scale comparison + n_mels = 80 + fb = _mel_filterbank(sr, n_fft, n_mels, 0, sr // 2) # [n_mels, n_freqs] + + gen_mel = torch.matmul(fb, gen_mag).clamp(min=1e-8) + ref_mel = torch.matmul(fb, ref_mag).clamp(min=1e-8) + + gen_mel_log = torch.log(gen_mel) + ref_mel_log = torch.log(ref_mel) + + # Mel cepstral distortion (L2 in log-mel space, standard scaling) + mcd = float( + (10.0 / np.log(10)) + * ((gen_mel_log - ref_mel_log) ** 2).mean(0).sqrt().mean() + ) + + # Per-band Pearson correlation + gen_np = gen_mel_log.numpy() + ref_np = ref_mel_log.numpy() + correlations = [] + for b in range(n_mels): + g, r = gen_np[b], ref_np[b] + if g.std() < 1e-8 or r.std() < 1e-8: + continue + corr = np.corrcoef(g, r)[0, 1] + if np.isfinite(corr): + correlations.append(corr) + mean_corr = float(np.mean(correlations)) if correlations else 0.0 + + return { + "log_spectral_distance_db": round(lsd, 2), + "mel_cepstral_distortion": round(mcd, 2), + "per_band_correlation": round(mean_corr, 4), + } + + def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None: """Save a log-frequency dB spectrogram PNG for an eval sample. @@ -778,6 +849,19 @@ class SelvaLoraTrainer: f"(step {start_step + 1} → {steps}, batch_size={batch_size}, " f"timestep_mode={timestep_mode})\n", flush=True) + # Load reference audio for eval comparison against original + ref_wav = None + try: + npz_files = sorted(data_dir.glob("*.npz")) + if npz_files: + ref_audio_path = _find_audio(npz_files[0]) + if ref_audio_path is not None: + ref_wav = _load_audio(ref_audio_path, seq_cfg.sampling_rate, + seq_cfg.duration).unsqueeze(0) # [1, L] + print(f"[LoRA Trainer] Reference audio: {ref_audio_path.name}", flush=True) + except Exception as e: + print(f"[LoRA Trainer] Could not load reference audio: {e}", flush=True) + last_step = start_step completed = False try: @@ -911,6 +995,13 @@ class SelvaLoraTrainer: f"flatness={metrics['spectral_flatness']:.3f} " f"temporal_var={metrics['temporal_variance']:.3f}", flush=True) _save_spectrogram(wav, sr, wav_path) + if ref_wav is not None: + ref_m = _reference_metrics(wav, ref_wav, sr) + spectral_metrics[step].update(ref_m) + print(f"[LoRA Trainer] vs Reference: " + f"LSD={ref_m['log_spectral_distance_db']:.1f}dB " + f"MCD={ref_m['mel_cepstral_distortion']:.2f} " + f"corr={ref_m['per_band_correlation']:.3f}", flush=True) except Exception as e: print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True)