From 304d9d01bf64c631fde47dabac9059713b35e083 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 03:03:28 +0200 Subject: [PATCH] feat: save mel spectrogram PNG alongside each eval sample MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds _save_spectrogram() using PIL only (no matplotlib). Each _save_sample call now writes both a .wav and a _spec.png so training progress is visible without listening. Colour map is blue→green→yellow (viridis-ish), low frequencies at the bottom. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_bigvgan_trainer.py | 40 ++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 9580355..22c955e 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -23,6 +23,37 @@ import folder_paths from .utils import SELVA_CATEGORY, get_device, soft_empty_cache +def _save_spectrogram(path, mel_tensor): + """Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep). + + Normalises to [0, 255], flips frequency axis so low freqs are at the bottom, + and saves as a greyscale PNG with a simple viridis-like colour map. + """ + try: + from PIL import Image + import numpy as np + + mel = mel_tensor.squeeze(0).float().cpu().numpy() # [n_mels, T] + mel = mel[::-1] # low freq at bottom + lo, hi = mel.min(), mel.max() + if hi > lo: + mel = (mel - lo) / (hi - lo) + else: + mel = mel - lo + img_u8 = (mel * 255).clip(0, 255).astype(np.uint8) + + # Simple blue→green→yellow colour map (viridis-ish) via LUT + lut_r = np.array([int(max(0, min(255, 255 * (v * 2 - 1)))) for v in np.linspace(0, 1, 256)], dtype=np.uint8) + lut_g = np.array([int(max(0, min(255, 255 * (1 - abs(v * 2 - 1))))) for v in np.linspace(0, 1, 256)], dtype=np.uint8) + lut_b = np.array([int(max(0, min(255, 255 * (1 - v * 2)))) for v in np.linspace(0, 1, 256)], dtype=np.uint8) + r = Image.fromarray(lut_r[img_u8]) + g = Image.fromarray(lut_g[img_u8]) + b = Image.fromarray(lut_b[img_u8]) + Image.merge("RGB", (r, g, b)).save(str(path)) + except Exception as e: + print(f"[BigVGAN] Spectrogram save failed: {e}", flush=True) + + def _save_wav(path, wav_tensor, sample_rate): """Save [channels, samples] float32 tensor to .wav. @@ -287,9 +318,14 @@ def _do_train(vocoder, mel_converter, clips, if wav.dim() == 2: wav = wav.unsqueeze(1) wav = wav.float().cpu().clamp(-1, 1) - wav_path = out_path.parent / f"{out_path.stem}_{label}.wav" + wav_path = out_path.parent / f"{out_path.stem}_{label}.wav" + spec_path = out_path.parent / f"{out_path.stem}_{label}_spec.png" _save_wav(wav_path, wav.squeeze(0), sample_rate) - print(f"[BigVGAN] Sample saved: {wav_path}", flush=True) + # Compute mel of the vocoded output for visual comparison + with torch.no_grad(): + pred_mel = mel_converter(wav.squeeze(1).to(mel_converter.mel_basis.device)) + _save_spectrogram(spec_path, pred_mel) + print(f"[BigVGAN] Sample saved: {wav_path} spec: {spec_path}", flush=True) except Exception as e: print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True)