def _spectral_metrics(
    wav: torch.Tensor,
    sr: int,
    *,
    hf_cutoff_hz: float = 4000.0,
    rolloff_fraction: float = 0.85,
    n_fft: "int | None" = None,
    hop_length: "int | None" = None,
) -> dict:
    """Compute spectral quality metrics for an eval audio sample.

    The signal is analysed with a Hann-windowed STFT; the power spectrum is
    averaged over time and three summary statistics are derived from it.

    Args:
        wav: Audio tensor, normally mono ``[1, L]``. A 1-D ``[L]`` tensor is
            accepted as-is; any leading (channel/batch) dimensions are
            averaged into a single mono signal before analysis.
        sr: Sample rate in Hz.
        hf_cutoff_hz: Lower edge (inclusive) of the "high frequency" band
            used for ``hf_energy_ratio``.
        rolloff_fraction: Fraction of total energy that defines the rolloff
            frequency.
        n_fft: FFT size. Defaults to the module constant ``_SPEC_N_FFT``.
        hop_length: STFT hop. Defaults to the module constant ``_SPEC_HOP``;
            it is clamped so it never exceeds ``n_fft``.

    Returns:
        A dict with:
            hf_energy_ratio      — energy at/above ``hf_cutoff_hz`` / total
                                   energy, rounded to 4 decimals
            spectral_centroid_hz — energy-weighted mean frequency, rounded
                                   to 1 decimal
            spectral_rolloff_hz  — lowest bin frequency at which the
                                   cumulative energy reaches
                                   ``rolloff_fraction`` of the total,
                                   rounded to 1 decimal
        All three values are ``0.0`` for a silent (zero-energy) signal.
    """
    if n_fft is None:
        n_fft = _SPEC_N_FFT
    if hop_length is None:
        hop_length = _SPEC_HOP
    hop = min(hop_length, n_fft)

    # Stay in torch end-to-end: no tensor -> numpy -> tensor round-trip.
    signal = wav.detach().to(device="cpu", dtype=torch.float32)
    if signal.dim() == 0:
        signal = signal.reshape(1)
    elif signal.dim() > 1:
        # Collapse leading dims to mono. For the usual [1, L] input this is
        # equivalent to squeezing the channel dimension.
        signal = signal.flatten(0, -2).mean(dim=0)

    # torch.stft reflect-pads by n_fft // 2 and raises when the signal is
    # shorter than that; zero-pad short clips up to one full frame instead.
    missing = n_fft - signal.numel()
    if missing > 0:
        signal = torch.cat([signal, signal.new_zeros(missing)])

    spec = torch.stft(
        signal,
        n_fft=n_fft,
        hop_length=hop,
        window=torch.hann_window(n_fft),
        return_complex=True,
    )
    # [n_freqs, n_frames] -> [n_freqs], averaged over time. Accumulate in
    # float64 so the sums/cumsum below are not limited by float32 precision.
    power = spec.abs().pow(2).mean(dim=-1).to(torch.float64)
    # Exact centre frequency of every one-sided FFT bin (valid for odd n_fft
    # as well, unlike an evenly spaced 0..sr/2 grid).
    freqs = torch.fft.rfftfreq(n_fft, d=1.0 / sr, dtype=torch.float64)

    total = float(power.sum())
    if total <= 0.0:
        # Silent input: every ratio would be 0/0, report zeros explicitly.
        return {
            "hf_energy_ratio": 0.0,
            "spectral_centroid_hz": 0.0,
            "spectral_rolloff_hz": 0.0,
        }

    hf_ratio = float(power[freqs >= hf_cutoff_hz].sum()) / total
    centroid = float((freqs * power).sum()) / total

    cumulative = torch.cumsum(power, dim=0)
    threshold = rolloff_fraction * cumulative[-1]
    # Number of bins strictly below the threshold == left insertion index
    # into the (non-decreasing) cumulative energy curve.
    rolloff_idx = int((cumulative < threshold).sum())
    rolloff = float(freqs[min(rolloff_idx, freqs.numel() - 1)])

    return {
        "hf_energy_ratio": round(hf_ratio, 4),
        "spectral_centroid_hz": round(centroid, 1),
        "spectral_rolloff_hz": round(rolloff, 1),
    }
) pbar_train = comfy.utils.ProgressBar(remaining) - loss_history = [] - running_loss = 0.0 + loss_history = [] + running_loss = 0.0 grad_norm_history = [] + spectral_metrics = {} # {step: {hf_energy_ratio, spectral_centroid_hz, spectral_rolloff_hz}} running_grad_norm = 0.0 grad_norm_count = 0 @@ -751,9 +786,14 @@ class SelvaLoraTrainer: sf.write(str(wav_path), wav.squeeze(0).numpy(), sr) print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True) try: + metrics = _spectral_metrics(wav, sr) + spectral_metrics[step] = metrics + print(f"[LoRA Trainer] Spectral: hf_ratio={metrics['hf_energy_ratio']:.3f} " + f"centroid={metrics['spectral_centroid_hz']:.0f}Hz " + f"rolloff={metrics['spectral_rolloff_hz']:.0f}Hz", flush=True) _save_spectrogram(wav, sr, wav_path) except Exception as e: - print(f"[LoRA Trainer] Spectrogram failed: {e}", flush=True) + print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True) last_step = step pbar_train.update(1) @@ -802,6 +842,7 @@ class SelvaLoraTrainer: "loss_curve": loss_curve, "loss_history": loss_history, "grad_norm_history": grad_norm_history, + "spectral_metrics": spectral_metrics, "start_step": start_step, "meta": meta, "completed": True,