feat: spectral metrics per eval sample in experiment summary
Computes hf_energy_ratio (share of energy at or above 4 kHz), spectral_centroid_hz, and spectral_rolloff_hz for the eval sample written at each save_every checkpoint. The metrics are logged to the console and stored in experiment_summary.json under results.spectral_metrics[step].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -429,6 +429,7 @@ class SelvaLoraScheduler:
|
||||
duration = time.monotonic() - t_start
|
||||
loss_history = r["loss_history"]
|
||||
grad_norm_history = r.get("grad_norm_history", [])
|
||||
spectral_metrics = r.get("spectral_metrics", {})
|
||||
run_start_step = r.get("start_step", 0)
|
||||
smoothed = _smooth_losses(loss_history) if loss_history else []
|
||||
|
||||
@@ -460,6 +461,7 @@ class SelvaLoraScheduler:
|
||||
),
|
||||
"loss_history": [round(v, 6) for v in loss_history],
|
||||
"grad_norm_history": grad_norm_history,
|
||||
"spectral_metrics": {str(k): v for k, v in spectral_metrics.items()},
|
||||
"log_interval": log_interval,
|
||||
"duration_seconds": round(duration, 1),
|
||||
}
|
||||
|
||||
@@ -159,6 +159,40 @@ _SPEC_DB_FLOOR = -80.0
|
||||
_SPEC_LOG_BINS = 256
|
||||
|
||||
|
||||
def _spectral_metrics(wav: torch.Tensor, sr: int) -> dict:
    """Compute spectral quality metrics for an eval waveform.

    The time-averaged STFT power spectrum is reduced to three scalars that
    summarise how much high-frequency content the audio carries.

    Args:
        wav: Audio samples with time on the last axis. The expected input is a
            mono ``[1, L]`` float32 CPU tensor; any leading channel/batch
            dimensions are averaged down to a single mono signal, and the
            tensor is detached and cast to float32 on the CPU before analysis.
        sr: Sample rate in Hz; the spectrum is mapped onto ``0 .. sr / 2``.

    Returns:
        A dict with:
            hf_energy_ratio      — energy at/above 4 kHz divided by total
                                   energy (low bitrate → low value),
                                   rounded to 4 decimals
            spectral_centroid_hz — energy-weighted mean frequency,
                                   rounded to 1 decimal
            spectral_rolloff_hz  — frequency below which 85% of the energy
                                   sits, rounded to 1 decimal
    """
    import numpy as np

    # Work on a detached float32 CPU copy so the dtype matches the Hann window
    # and grad-tracking / non-CPU inputs do not break the numpy conversion below.
    samples = wav.detach().cpu().float()
    if samples.dim() > 1:
        # Fold channel/batch dims into one mono signal. For the usual [1, L]
        # input this is numerically the same as dropping the leading dim.
        samples = samples.reshape(-1, samples.shape[-1]).mean(dim=0)

    # _SPEC_HOP / _SPEC_N_FFT are module-level spectrogram settings; the hop is
    # clamped so it never exceeds the FFT size.
    hop = min(_SPEC_HOP, _SPEC_N_FFT)
    window = torch.hann_window(_SPEC_N_FFT)
    stft = torch.stft(
        samples,
        n_fft=_SPEC_N_FFT,
        hop_length=hop,
        window=window,
        return_complex=True,
    )

    # Frames are on the last axis: average over time -> one power per bin.
    power = stft.abs().pow(2).mean(dim=-1).numpy()
    # One-sided spectrum: bins are evenly spaced from DC to Nyquist.
    freqs = np.linspace(0, sr / 2, len(power))

    # Epsilon keeps the ratios finite for digital silence.
    total = power.sum() + 1e-12

    hf_ratio = float(power[freqs >= 4000].sum() / total)
    centroid = float((freqs * power).sum() / total)

    # First bin at which the cumulative energy reaches 85% of the total;
    # clamp in case searchsorted lands one past the final bin.
    cumulative = np.cumsum(power)
    rolloff_idx = int(np.searchsorted(cumulative, 0.85 * cumulative[-1]))
    rolloff = float(freqs[min(rolloff_idx, len(freqs) - 1)])

    return {
        "hf_energy_ratio": round(hf_ratio, 4),
        "spectral_centroid_hz": round(centroid, 1),
        "spectral_rolloff_hz": round(rolloff, 1),
    }
|
||||
|
||||
|
||||
def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None:
|
||||
"""Save a log-frequency dB spectrogram PNG for an eval sample.
|
||||
|
||||
@@ -629,9 +663,10 @@ class SelvaLoraTrainer:
|
||||
"record any loss — increase 'steps' or lower the resume checkpoint."
|
||||
)
|
||||
pbar_train = comfy.utils.ProgressBar(remaining)
|
||||
loss_history = []
|
||||
running_loss = 0.0
|
||||
loss_history = []
|
||||
running_loss = 0.0
|
||||
grad_norm_history = []
|
||||
spectral_metrics = {} # {step: {hf_energy_ratio, spectral_centroid_hz, spectral_rolloff_hz}}
|
||||
running_grad_norm = 0.0
|
||||
grad_norm_count = 0
|
||||
|
||||
@@ -751,9 +786,14 @@ class SelvaLoraTrainer:
|
||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||||
try:
|
||||
metrics = _spectral_metrics(wav, sr)
|
||||
spectral_metrics[step] = metrics
|
||||
print(f"[LoRA Trainer] Spectral: hf_ratio={metrics['hf_energy_ratio']:.3f} "
|
||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
||||
f"rolloff={metrics['spectral_rolloff_hz']:.0f}Hz", flush=True)
|
||||
_save_spectrogram(wav, sr, wav_path)
|
||||
except Exception as e:
|
||||
print(f"[LoRA Trainer] Spectrogram failed: {e}", flush=True)
|
||||
print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True)
|
||||
|
||||
last_step = step
|
||||
pbar_train.update(1)
|
||||
@@ -802,6 +842,7 @@ class SelvaLoraTrainer:
|
||||
"loss_curve": loss_curve,
|
||||
"loss_history": loss_history,
|
||||
"grad_norm_history": grad_norm_history,
|
||||
"spectral_metrics": spectral_metrics,
|
||||
"start_step": start_step,
|
||||
"meta": meta,
|
||||
"completed": True,
|
||||
|
||||
Reference in New Issue
Block a user