feat: spectral metrics per eval sample in experiment summary

Computes hf_energy_ratio (energy at or above 4 kHz), spectral_centroid_hz, and
spectral_rolloff_hz at each save_every checkpoint. The metrics are logged to the
console and stored in experiment_summary.json under results.spectral_metrics[step].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 12:44:43 +02:00
parent c4687521ef
commit 2861327016
2 changed files with 46 additions and 3 deletions
+2
View File
@@ -429,6 +429,7 @@ class SelvaLoraScheduler:
duration = time.monotonic() - t_start
loss_history = r["loss_history"]
grad_norm_history = r.get("grad_norm_history", [])
spectral_metrics = r.get("spectral_metrics", {})
run_start_step = r.get("start_step", 0)
smoothed = _smooth_losses(loss_history) if loss_history else []
@@ -460,6 +461,7 @@ class SelvaLoraScheduler:
),
"loss_history": [round(v, 6) for v in loss_history],
"grad_norm_history": grad_norm_history,
"spectral_metrics": {str(k): v for k, v in spectral_metrics.items()},
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
+42 -1
View File
@@ -159,6 +159,40 @@ _SPEC_DB_FLOOR = -80.0
_SPEC_LOG_BINS = 256
def _spectral_metrics(
    wav: torch.Tensor,
    sr: int,
    *,
    hf_cutoff_hz: float = 4000.0,
    rolloff_fraction: float = 0.85,
) -> dict:
    """Compute spectral quality metrics for a mono [1, L] float32 CPU tensor.

    The signal is run through a Hann-windowed STFT (``_SPEC_N_FFT`` points, hop
    length capped at the FFT size), the power spectrum is averaged over all
    frames, and three summary statistics are derived from that average.

    Args:
        wav: Audio samples; a size-1 leading dimension is squeezed away
            before the STFT.
        sr: Sample rate in Hz. Bin frequencies are spread linearly over
            0 .. sr / 2.
        hf_cutoff_hz: Lower edge (inclusive) of the "high-frequency" band
            used for ``hf_energy_ratio``. Defaults to 4 kHz.
        rolloff_fraction: Share of the total energy that must lie at or below
            the reported roll-off frequency. Must be in (0, 1]. Defaults
            to 0.85.

    Returns:
        A dict with:
        hf_energy_ratio      — energy at or above ``hf_cutoff_hz`` / total energy
                               (low bitrate → low value), rounded to 4 decimals
        spectral_centroid_hz — energy-weighted mean frequency, rounded to 0.1 Hz
        spectral_rolloff_hz  — frequency below which ``rolloff_fraction`` of the
                               energy sits, rounded to 0.1 Hz

    Raises:
        ValueError: If ``rolloff_fraction`` is outside (0, 1].
    """
    import numpy as np

    if not 0.0 < rolloff_fraction <= 1.0:
        raise ValueError(
            f"rolloff_fraction must be in (0, 1], got {rolloff_fraction!r}"
        )

    hop = min(_SPEC_HOP, _SPEC_N_FFT)
    window = torch.hann_window(_SPEC_N_FFT)
    # Feed the squeezed tensor to the STFT directly; a tensor -> NumPy -> tensor
    # round trip adds nothing here.
    stft = torch.stft(
        wav.squeeze(0),
        n_fft=_SPEC_N_FFT,
        hop_length=hop,
        window=window,
        return_complex=True,
    )
    power = stft.abs().pow(2).mean(dim=1).numpy()  # [n_freqs] averaged over time
    freqs = np.linspace(0, sr / 2, len(power))

    # The epsilon keeps the ratios finite (and zero) for an all-silent signal.
    total = power.sum() + 1e-12

    hf_mask = freqs >= hf_cutoff_hz
    hf_ratio = float(power[hf_mask].sum() / total)

    centroid = float((freqs * power).sum() / total)

    # First bin at which the cumulative energy reaches the requested fraction.
    cumsum = np.cumsum(power)
    rolloff_idx = np.searchsorted(cumsum, rolloff_fraction * cumsum[-1])
    # Clamp in case rounding pushes the insertion point past the last bin.
    rolloff = float(freqs[min(rolloff_idx, len(freqs) - 1)])

    return {
        "hf_energy_ratio": round(hf_ratio, 4),
        "spectral_centroid_hz": round(centroid, 1),
        "spectral_rolloff_hz": round(rolloff, 1),
    }
def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None:
"""Save a log-frequency dB spectrogram PNG for an eval sample.
@@ -632,6 +666,7 @@ class SelvaLoraTrainer:
loss_history = []
running_loss = 0.0
grad_norm_history = []
spectral_metrics = {} # {step: {hf_energy_ratio, spectral_centroid_hz, spectral_rolloff_hz}}
running_grad_norm = 0.0
grad_norm_count = 0
@@ -751,9 +786,14 @@ class SelvaLoraTrainer:
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
try:
metrics = _spectral_metrics(wav, sr)
spectral_metrics[step] = metrics
print(f"[LoRA Trainer] Spectral: hf_ratio={metrics['hf_energy_ratio']:.3f} "
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
f"rolloff={metrics['spectral_rolloff_hz']:.0f}Hz", flush=True)
_save_spectrogram(wav, sr, wav_path)
except Exception as e:
print(f"[LoRA Trainer] Spectrogram failed: {e}", flush=True)
print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True)
last_step = step
pbar_train.update(1)
@@ -802,6 +842,7 @@ class SelvaLoraTrainer:
"loss_curve": loss_curve,
"loss_history": loss_history,
"grad_norm_history": grad_norm_history,
"spectral_metrics": spectral_metrics,
"start_step": start_step,
"meta": meta,
"completed": True,