feat: save spectrogram PNG alongside each eval sample

Log-frequency dB spectrogram (inferno colormap, Hz tick labels from 100 Hz
to 16 kHz) saved as step_XXXXX.png next to step_XXXXX.wav in the samples/
subfolder. Makes high-frequency rolloff (a low-bitrate signature) immediately visible.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 12:42:34 +02:00
parent 8717af2728
commit c4687521ef
+70
View File
@@ -149,6 +149,72 @@ def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
generator.train() generator.train()
# ---------------------------------------------------------------------------
# Eval spectrogram rendering
# ---------------------------------------------------------------------------
# FFT size handed to torch.stft as n_fft; also the length of the Hann window.
_SPEC_N_FFT = 2048
# STFT hop length in samples (capped at _SPEC_N_FFT where it is used).
_SPEC_HOP = 512
# Dynamic range in dB relative to the loudest bin: values below
# (peak + _SPEC_DB_FLOOR) are clamped up to that level.
_SPEC_DB_FLOOR = -80.0
# Number of log-spaced frequency rows in the rendered image.
_SPEC_LOG_BINS = 256
def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None:
    """Save a log-frequency dB spectrogram PNG for an eval sample.

    Args:
        wav: Audio tensor, normally ``[1, L]`` float32 on CPU (mono). A 1-D
            ``[L]`` tensor is accepted as well, and a multi-channel
            ``[C, L]`` tensor is averaged down to mono. The tensor is
            detached and converted to float32 on CPU before analysis.
        sr: Sample rate in Hz; used only to place the Hz tick labels.
        path: Output location. The suffix is replaced with ``.png``, so the
            path of the matching ``.wav`` file can be passed directly.

    Raises:
        ValueError: If the clip has ``_SPEC_N_FFT // 2`` samples or fewer,
            which is too short for the centred (reflect-padded) STFT.
    """
    import numpy as np
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    # Normalise the input once: no autograd history, CPU, float32 so it
    # matches the float32 Hann window below.
    samples = wav.detach().to(device="cpu", dtype=torch.float32)
    if samples.dim() == 0 or samples.shape[-1] <= _SPEC_N_FFT // 2:
        n_samples = 0 if samples.dim() == 0 else samples.shape[-1]
        raise ValueError(
            f"need more than {_SPEC_N_FFT // 2} samples for a spectrogram, "
            f"got {n_samples}"
        )
    # Collapse any leading dims to channels and downmix; for the usual
    # [1, L] input this is an exact no-op on the sample values.
    samples = samples.reshape(-1, samples.shape[-1]).mean(dim=0)

    hop = min(_SPEC_HOP, _SPEC_N_FFT)
    window = torch.hann_window(_SPEC_N_FFT)
    stft = torch.stft(samples, n_fft=_SPEC_N_FFT, hop_length=hop,
                      window=window, return_complex=True)
    mag = stft.abs().numpy()

    # Magnitude -> dB, then clamp everything more than |_SPEC_DB_FLOOR| dB
    # below the peak up to that floor.
    db = 20.0 * np.log10(np.maximum(mag, 1e-8))
    db = np.maximum(db, db.max() + _SPEC_DB_FLOOR).astype(np.float32)

    # Log-frequency resampling: pick _SPEC_LOG_BINS fractional source bins,
    # log-spaced from bin 1 to the last bin, and linearly interpolate
    # between the two neighbouring STFT rows.
    n_freqs = db.shape[0]
    src_idx = np.logspace(0, np.log10(max(n_freqs - 1, 2)), _SPEC_LOG_BINS)
    lo = np.floor(src_idx).astype(int).clip(0, n_freqs - 2)
    frac = (src_idx - lo)[:, None]
    spec = ((1 - frac) * db[lo] + frac * db[lo + 1]).astype(np.float32)
    spec = spec[::-1]  # low freq at bottom

    # Y-tick positions (Hz labels). Targets that fall below bin 1 or beyond
    # the available bins for this sample rate are skipped.
    tgt_hz = [100, 500, 1000, 2000, 4000, 8000, 16000]
    tpos, tlbl = [], []
    for hz in tgt_hz:
        bin_f = hz * _SPEC_N_FFT / sr
        if bin_f < 1 or bin_f >= n_freqs:
            continue
        pos = int(np.searchsorted(src_idx, bin_f))
        # Rows were flipped above, so mirror the index as well.
        tpos.append(_SPEC_LOG_BINS - 1 - min(pos, _SPEC_LOG_BINS - 1))
        tlbl.append(f"{hz // 1000}k" if hz >= 1000 else str(hz))

    # Percentile-based colour limits keep a few outlier bins from
    # dominating the colour scale.
    vmin = float(np.percentile(spec, 2.0))
    vmax = float(np.percentile(spec, 99.5))
    if vmax <= vmin:
        # Constant (e.g. silent) input: avoid a zero-width colour range.
        vmax = vmin + 1.0

    # Build the figure directly on the Agg canvas, without pyplot.
    fig = Figure(figsize=(12, 3), dpi=120, tight_layout=True)
    ax = fig.add_subplot(1, 1, 1)
    im = ax.imshow(spec, aspect="auto", cmap="inferno", origin="upper",
                   vmin=vmin, vmax=vmax, interpolation="antialiased")
    ax.set_yticks(tpos)
    ax.set_yticklabels(tlbl, fontsize=8)
    ax.set_ylabel("Hz", fontsize=9)
    ax.set_xlabel("Time frames", fontsize=9)
    ax.set_title(path.stem, fontsize=9)
    fig.colorbar(im, ax=ax, label="dB", fraction=0.02, pad=0.01)

    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    canvas.print_figure(str(path.with_suffix(".png")), dpi=120)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Loss curve rendering # Loss curve rendering
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -684,6 +750,10 @@ class SelvaLoraTrainer:
import soundfile as sf import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr) sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True) print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
try:
_save_spectrogram(wav, sr, wav_path)
except Exception as e:
print(f"[LoRA Trainer] Spectrogram failed: {e}", flush=True)
last_step = step last_step = step
pbar_train.update(1) pbar_train.update(1)