feat: save spectrogram PNG alongside each eval sample
Log-frequency dB spectrogram (inferno colormap, 100Hz–16kHz) saved as step_XXXXX.png next to step_XXXXX.wav in samples/ subfolder. Makes high-frequency rolloff (low bitrate signature) immediately visible. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -149,6 +149,72 @@ def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
|
|||||||
generator.train()
|
generator.train()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Eval spectrogram rendering
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_SPEC_N_FFT = 2048
|
||||||
|
_SPEC_HOP = 512
|
||||||
|
_SPEC_DB_FLOOR = -80.0
|
||||||
|
_SPEC_LOG_BINS = 256
|
||||||
|
|
||||||
|
|
||||||
|
def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None:
|
||||||
|
"""Save a log-frequency dB spectrogram PNG for an eval sample.
|
||||||
|
|
||||||
|
wav: [1, L] float32 CPU tensor (mono).
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||||
|
|
||||||
|
wav_np = wav.squeeze(0).numpy()
|
||||||
|
hop = min(_SPEC_HOP, _SPEC_N_FFT)
|
||||||
|
window = torch.hann_window(_SPEC_N_FFT)
|
||||||
|
stft = torch.stft(torch.from_numpy(wav_np), n_fft=_SPEC_N_FFT, hop_length=hop,
|
||||||
|
window=window, return_complex=True)
|
||||||
|
mag = stft.abs().numpy()
|
||||||
|
db = 20.0 * np.log10(np.maximum(mag, 1e-8))
|
||||||
|
db = np.maximum(db, db.max() + _SPEC_DB_FLOOR).astype(np.float32)
|
||||||
|
|
||||||
|
# Log-frequency resampling
|
||||||
|
n_freqs = db.shape[0]
|
||||||
|
src_idx = np.logspace(0, np.log10(max(n_freqs - 1, 2)), _SPEC_LOG_BINS)
|
||||||
|
lo = np.floor(src_idx).astype(int).clip(0, n_freqs - 2)
|
||||||
|
frac = (src_idx - lo)[:, None]
|
||||||
|
spec = ((1 - frac) * db[lo] + frac * db[lo + 1]).astype(np.float32)
|
||||||
|
spec = spec[::-1] # low freq at bottom
|
||||||
|
|
||||||
|
# Y-tick positions (Hz labels)
|
||||||
|
tgt_hz = [100, 500, 1000, 2000, 4000, 8000, 16000]
|
||||||
|
tpos, tlbl = [], []
|
||||||
|
for hz in tgt_hz:
|
||||||
|
bin_f = hz * _SPEC_N_FFT / sr
|
||||||
|
if bin_f < 1 or bin_f >= n_freqs:
|
||||||
|
continue
|
||||||
|
pos = int(np.searchsorted(src_idx, bin_f))
|
||||||
|
tpos.append(_SPEC_LOG_BINS - 1 - min(pos, _SPEC_LOG_BINS - 1))
|
||||||
|
tlbl.append(f"{hz // 1000}k" if hz >= 1000 else str(hz))
|
||||||
|
|
||||||
|
vmin = float(np.percentile(spec, 2.0))
|
||||||
|
vmax = float(np.percentile(spec, 99.5))
|
||||||
|
|
||||||
|
fig = Figure(figsize=(12, 3), dpi=120, tight_layout=True)
|
||||||
|
ax = fig.add_subplot(1, 1, 1)
|
||||||
|
im = ax.imshow(spec, aspect="auto", cmap="inferno", origin="upper",
|
||||||
|
vmin=vmin, vmax=vmax, interpolation="antialiased")
|
||||||
|
ax.set_yticks(tpos)
|
||||||
|
ax.set_yticklabels(tlbl, fontsize=8)
|
||||||
|
ax.set_ylabel("Hz", fontsize=9)
|
||||||
|
ax.set_xlabel("Time frames", fontsize=9)
|
||||||
|
ax.set_title(path.stem, fontsize=9)
|
||||||
|
fig.colorbar(im, ax=ax, label="dB", fraction=0.02, pad=0.01)
|
||||||
|
|
||||||
|
canvas = FigureCanvasAgg(fig)
|
||||||
|
canvas.draw()
|
||||||
|
canvas.print_figure(str(path.with_suffix(".png")), dpi=120)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Loss curve rendering
|
# Loss curve rendering
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -684,6 +750,10 @@ class SelvaLoraTrainer:
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||||
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||||||
|
try:
|
||||||
|
_save_spectrogram(wav, sr, wav_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[LoRA Trainer] Spectrogram failed: {e}", flush=True)
|
||||||
|
|
||||||
last_step = step
|
last_step = step
|
||||||
pbar_train.update(1)
|
pbar_train.update(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user