feat: add spectral_flatness and temporal_variance to eval metrics

spectral_flatness (Wiener entropy) — 0=tonal, 1=white noise.
Rising value across steps directly flags noise contamination.
temporal_variance — RMS std/mean per frame. Low = lifeless/compressed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 12:45:40 +02:00
parent 2861327016
commit fec5c86f09
+19 -1
View File
@@ -186,10 +186,26 @@ def _spectral_metrics(wav: torch.Tensor, sr: int) -> dict:
rolloff_idx = np.searchsorted(cumsum, 0.85 * cumsum[-1])
rolloff = float(freqs[min(rolloff_idx, len(freqs) - 1)])
# Spectral flatness (Wiener entropy): geometric_mean / arithmetic_mean of power
# 0.0 = pure tone, 1.0 = white noise — rising value = noise contamination
log_power = np.log(power + 1e-12)
flatness = float(np.exp(log_power.mean()) / (power.mean() + 1e-12))
# Temporal energy variance — how dynamic the audio is
# Compute RMS per frame, take std. Low value = compressed/lifeless
hop = min(_SPEC_HOP, _SPEC_N_FFT)
window = torch.hann_window(_SPEC_N_FFT)
stft_full = torch.stft(torch.from_numpy(wav_np), n_fft=_SPEC_N_FFT, hop_length=hop,
window=window, return_complex=True)
frame_rms = stft_full.abs().pow(2).mean(dim=0).sqrt().numpy() # [n_frames]
temporal_variance = float(frame_rms.std() / (frame_rms.mean() + 1e-12))
return {
"hf_energy_ratio": round(hf_ratio, 4),
"spectral_centroid_hz": round(centroid, 1),
"spectral_rolloff_hz": round(rolloff, 1),
"spectral_flatness": round(flatness, 4),
"temporal_variance": round(temporal_variance, 4),
}
@@ -790,7 +806,9 @@ class SelvaLoraTrainer:
spectral_metrics[step] = metrics
print(f"[LoRA Trainer] Spectral: hf_ratio={metrics['hf_energy_ratio']:.3f} "
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
f"rolloff={metrics['spectral_rolloff_hz']:.0f}Hz", flush=True)
f"rolloff={metrics['spectral_rolloff_hz']:.0f}Hz "
f"flatness={metrics['spectral_flatness']:.3f} "
f"temporal_var={metrics['temporal_variance']:.3f}", flush=True)
_save_spectrogram(wav, sr, wav_path)
except Exception as e:
print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True)