feat: add spectral_flatness and temporal_variance to eval metrics
spectral_flatness (Wiener entropy) — 0=tonal, 1=white noise. Rising value across steps directly flags noise contamination. temporal_variance — RMS std/mean per frame. Low = lifeless/compressed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -186,10 +186,26 @@ def _spectral_metrics(wav: torch.Tensor, sr: int) -> dict:
|
|||||||
rolloff_idx = np.searchsorted(cumsum, 0.85 * cumsum[-1])
|
rolloff_idx = np.searchsorted(cumsum, 0.85 * cumsum[-1])
|
||||||
rolloff = float(freqs[min(rolloff_idx, len(freqs) - 1)])
|
rolloff = float(freqs[min(rolloff_idx, len(freqs) - 1)])
|
||||||
|
|
||||||
|
# Spectral flatness (Wiener entropy): geometric_mean / arithmetic_mean of power
|
||||||
|
# 0.0 = pure tone, 1.0 = white noise — rising value = noise contamination
|
||||||
|
log_power = np.log(power + 1e-12)
|
||||||
|
flatness = float(np.exp(log_power.mean()) / (power.mean() + 1e-12))
|
||||||
|
|
||||||
|
# Temporal energy variance — how dynamic the audio is
|
||||||
|
# Compute RMS per frame, take std. Low value = compressed/lifeless
|
||||||
|
hop = min(_SPEC_HOP, _SPEC_N_FFT)
|
||||||
|
window = torch.hann_window(_SPEC_N_FFT)
|
||||||
|
stft_full = torch.stft(torch.from_numpy(wav_np), n_fft=_SPEC_N_FFT, hop_length=hop,
|
||||||
|
window=window, return_complex=True)
|
||||||
|
frame_rms = stft_full.abs().pow(2).mean(dim=0).sqrt().numpy() # [n_frames]
|
||||||
|
temporal_variance = float(frame_rms.std() / (frame_rms.mean() + 1e-12))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"hf_energy_ratio": round(hf_ratio, 4),
|
"hf_energy_ratio": round(hf_ratio, 4),
|
||||||
"spectral_centroid_hz": round(centroid, 1),
|
"spectral_centroid_hz": round(centroid, 1),
|
||||||
"spectral_rolloff_hz": round(rolloff, 1),
|
"spectral_rolloff_hz": round(rolloff, 1),
|
||||||
|
"spectral_flatness": round(flatness, 4),
|
||||||
|
"temporal_variance": round(temporal_variance, 4),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -790,7 +806,9 @@ class SelvaLoraTrainer:
|
|||||||
spectral_metrics[step] = metrics
|
spectral_metrics[step] = metrics
|
||||||
print(f"[LoRA Trainer] Spectral: hf_ratio={metrics['hf_energy_ratio']:.3f} "
|
print(f"[LoRA Trainer] Spectral: hf_ratio={metrics['hf_energy_ratio']:.3f} "
|
||||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
||||||
f"rolloff={metrics['spectral_rolloff_hz']:.0f}Hz", flush=True)
|
f"rolloff={metrics['spectral_rolloff_hz']:.0f}Hz "
|
||||||
|
f"flatness={metrics['spectral_flatness']:.3f} "
|
||||||
|
f"temporal_var={metrics['temporal_variance']:.3f}", flush=True)
|
||||||
_save_spectrogram(wav, sr, wav_path)
|
_save_spectrogram(wav, sr, wav_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True)
|
print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user