# ---------------------------------------------------------------------------
# Eval spectrogram rendering
# ---------------------------------------------------------------------------

_SPEC_N_FFT = 2048      # STFT window / FFT size, in samples
_SPEC_HOP = 512         # STFT hop, in samples
_SPEC_DB_FLOOR = -80.0  # dB kept below the loudest bin; quieter bins are clamped
_SPEC_LOG_BINS = 256    # number of rows in the log-frequency image


def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None:
    """Save a log-frequency dB spectrogram PNG for an eval sample.

    The image is written next to ``path`` with the suffix replaced by
    ``.png``; ``path.stem`` is used as the plot title.  The trainer calls this
    right after writing the eval WAV and catches/logs any exception, so a
    failure here never interrupts training.

    Args:
        wav: Audio tensor, normally ``[1, L]`` float32 on CPU (mono).  Extra
            leading dimensions are averaged away, and other dtypes/devices are
            converted to float32 on CPU.
        sr: Sample rate in Hz; only used to place the Hz tick labels.
        path: Path of the eval sample (e.g. the ``.wav`` file).
    """
    import numpy as np
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    # Stay in torch: no numpy round-trip before the STFT.
    mono = wav.detach().to(device="cpu", dtype=torch.float32)
    while mono.dim() > 1:
        # For the expected [1, L] input this is an exact squeeze; for real
        # multi-channel input it is a plain average downmix.
        mono = mono.mean(dim=0)
    mono = mono.reshape(-1)

    # torch.stft centres frames with reflect padding of n_fft // 2 samples,
    # which requires the signal to be strictly longer than the padding.
    # Zero-pad very short (or empty) clips instead of raising.
    min_len = _SPEC_N_FFT // 2 + 1
    if mono.numel() < min_len:
        mono = torch.nn.functional.pad(mono, (0, min_len - mono.numel()))

    hop = min(_SPEC_HOP, _SPEC_N_FFT)
    stft = torch.stft(
        mono,
        n_fft=_SPEC_N_FFT,
        hop_length=hop,
        window=torch.hann_window(_SPEC_N_FFT),
        return_complex=True,
    )
    mag = stft.abs().numpy()  # [n_freqs, n_frames]

    # Magnitude -> dB, then clamp everything more than |_SPEC_DB_FLOOR| dB
    # below the loudest bin so the colour range is not wasted on silence.
    db = 20.0 * np.log10(np.maximum(mag, 1e-8))
    db = np.maximum(db, db.max() + _SPEC_DB_FLOOR).astype(np.float32)

    # Log-frequency resampling: pick _SPEC_LOG_BINS fractional source bins,
    # log-spaced from bin 1 up to the last bin, and linearly interpolate
    # between the two neighbouring linear-frequency rows.
    n_freqs = db.shape[0]
    src_idx = np.logspace(0, np.log10(max(n_freqs - 1, 2)), _SPEC_LOG_BINS)
    lo = np.floor(src_idx).astype(int).clip(0, n_freqs - 2)
    frac = (src_idx - lo)[:, None]
    spec = ((1 - frac) * db[lo] + frac * db[lo + 1]).astype(np.float32)
    spec = spec[::-1]  # flip so low frequencies end up at the bottom row

    # Y-tick positions (Hz labels): map each target frequency to its
    # fractional FFT bin, find the matching log-spaced row, then mirror the
    # row index because the image was flipped above.
    tgt_hz = [100, 500, 1000, 2000, 4000, 8000, 16000]
    tpos, tlbl = [], []
    for hz in tgt_hz:
        bin_f = hz * _SPEC_N_FFT / sr
        if bin_f < 1 or bin_f >= n_freqs:
            continue  # outside the plotted range (below bin 1 or above the last bin)
        pos = int(np.searchsorted(src_idx, bin_f))
        tpos.append(_SPEC_LOG_BINS - 1 - min(pos, _SPEC_LOG_BINS - 1))
        tlbl.append(f"{hz // 1000}k" if hz >= 1000 else str(hz))

    # Percentile-based colour limits keep a few outlier bins from
    # dominating the colour scale.
    vmin = float(np.percentile(spec, 2.0))
    vmax = float(np.percentile(spec, 99.5))

    # Build the figure without pyplot so no global figure state is created.
    fig = Figure(figsize=(12, 3), dpi=120, tight_layout=True)
    ax = fig.add_subplot(1, 1, 1)
    im = ax.imshow(spec, aspect="auto", cmap="inferno", origin="upper",
                   vmin=vmin, vmax=vmax, interpolation="antialiased")
    ax.set_yticks(tpos)
    ax.set_yticklabels(tlbl, fontsize=8)
    ax.set_ylabel("Hz", fontsize=9)
    ax.set_xlabel("Time frames", fontsize=9)
    ax.set_title(path.stem, fontsize=9)
    fig.colorbar(im, ax=ax, label="dB", fraction=0.02, pad=0.01)

    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    canvas.print_figure(str(path.with_suffix(".png")), dpi=120)