feat: save spectrogram PNG alongside each eval sample
Log-frequency dB spectrogram (inferno colormap, 100Hz–16kHz) saved as step_XXXXX.png next to step_XXXXX.wav in samples/ subfolder. Makes high-frequency rolloff (low bitrate signature) immediately visible. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -149,6 +149,72 @@ def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
|
||||
generator.train()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Eval spectrogram rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SPEC_N_FFT = 2048
|
||||
_SPEC_HOP = 512
|
||||
_SPEC_DB_FLOOR = -80.0
|
||||
_SPEC_LOG_BINS = 256
|
||||
|
||||
|
||||
def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None:
|
||||
"""Save a log-frequency dB spectrogram PNG for an eval sample.
|
||||
|
||||
wav: [1, L] float32 CPU tensor (mono).
|
||||
"""
|
||||
import numpy as np
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
wav_np = wav.squeeze(0).numpy()
|
||||
hop = min(_SPEC_HOP, _SPEC_N_FFT)
|
||||
window = torch.hann_window(_SPEC_N_FFT)
|
||||
stft = torch.stft(torch.from_numpy(wav_np), n_fft=_SPEC_N_FFT, hop_length=hop,
|
||||
window=window, return_complex=True)
|
||||
mag = stft.abs().numpy()
|
||||
db = 20.0 * np.log10(np.maximum(mag, 1e-8))
|
||||
db = np.maximum(db, db.max() + _SPEC_DB_FLOOR).astype(np.float32)
|
||||
|
||||
# Log-frequency resampling
|
||||
n_freqs = db.shape[0]
|
||||
src_idx = np.logspace(0, np.log10(max(n_freqs - 1, 2)), _SPEC_LOG_BINS)
|
||||
lo = np.floor(src_idx).astype(int).clip(0, n_freqs - 2)
|
||||
frac = (src_idx - lo)[:, None]
|
||||
spec = ((1 - frac) * db[lo] + frac * db[lo + 1]).astype(np.float32)
|
||||
spec = spec[::-1] # low freq at bottom
|
||||
|
||||
# Y-tick positions (Hz labels)
|
||||
tgt_hz = [100, 500, 1000, 2000, 4000, 8000, 16000]
|
||||
tpos, tlbl = [], []
|
||||
for hz in tgt_hz:
|
||||
bin_f = hz * _SPEC_N_FFT / sr
|
||||
if bin_f < 1 or bin_f >= n_freqs:
|
||||
continue
|
||||
pos = int(np.searchsorted(src_idx, bin_f))
|
||||
tpos.append(_SPEC_LOG_BINS - 1 - min(pos, _SPEC_LOG_BINS - 1))
|
||||
tlbl.append(f"{hz // 1000}k" if hz >= 1000 else str(hz))
|
||||
|
||||
vmin = float(np.percentile(spec, 2.0))
|
||||
vmax = float(np.percentile(spec, 99.5))
|
||||
|
||||
fig = Figure(figsize=(12, 3), dpi=120, tight_layout=True)
|
||||
ax = fig.add_subplot(1, 1, 1)
|
||||
im = ax.imshow(spec, aspect="auto", cmap="inferno", origin="upper",
|
||||
vmin=vmin, vmax=vmax, interpolation="antialiased")
|
||||
ax.set_yticks(tpos)
|
||||
ax.set_yticklabels(tlbl, fontsize=8)
|
||||
ax.set_ylabel("Hz", fontsize=9)
|
||||
ax.set_xlabel("Time frames", fontsize=9)
|
||||
ax.set_title(path.stem, fontsize=9)
|
||||
fig.colorbar(im, ax=ax, label="dB", fraction=0.02, pad=0.01)
|
||||
|
||||
canvas = FigureCanvasAgg(fig)
|
||||
canvas.draw()
|
||||
canvas.print_figure(str(path.with_suffix(".png")), dpi=120)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Loss curve rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -684,6 +750,10 @@ class SelvaLoraTrainer:
|
||||
import soundfile as sf
|
||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||||
try:
|
||||
_save_spectrogram(wav, sr, wav_path)
|
||||
except Exception as e:
|
||||
print(f"[LoRA Trainer] Spectrogram failed: {e}", flush=True)
|
||||
|
||||
last_step = step
|
||||
pbar_train.update(1)
|
||||
|
||||
Reference in New Issue
Block a user