feat: save mel spectrogram PNG alongside each eval sample

Adds _save_spectrogram() using PIL only (no matplotlib). Each _save_sample
call now writes both a .wav and a _spec.png so training progress is visible
without listening. Colour map is blue→green→yellow (viridis-ish), low
frequencies at the bottom.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 03:03:28 +02:00
parent 0128a81cc2
commit 304d9d01bf
+37 -1
View File
@@ -23,6 +23,37 @@ import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
def _save_spectrogram(path, mel_tensor):
"""Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep).
Normalises to [0, 255], flips frequency axis so low freqs are at the bottom,
and saves as a greyscale PNG with a simple viridis-like colour map.
"""
try:
from PIL import Image
import numpy as np
mel = mel_tensor.squeeze(0).float().cpu().numpy() # [n_mels, T]
mel = mel[::-1] # low freq at bottom
lo, hi = mel.min(), mel.max()
if hi > lo:
mel = (mel - lo) / (hi - lo)
else:
mel = mel - lo
img_u8 = (mel * 255).clip(0, 255).astype(np.uint8)
# Simple blue→green→yellow colour map (viridis-ish) via LUT
lut_r = np.array([int(max(0, min(255, 255 * (v * 2 - 1)))) for v in np.linspace(0, 1, 256)], dtype=np.uint8)
lut_g = np.array([int(max(0, min(255, 255 * (1 - abs(v * 2 - 1))))) for v in np.linspace(0, 1, 256)], dtype=np.uint8)
lut_b = np.array([int(max(0, min(255, 255 * (1 - v * 2)))) for v in np.linspace(0, 1, 256)], dtype=np.uint8)
r = Image.fromarray(lut_r[img_u8])
g = Image.fromarray(lut_g[img_u8])
b = Image.fromarray(lut_b[img_u8])
Image.merge("RGB", (r, g, b)).save(str(path))
except Exception as e:
print(f"[BigVGAN] Spectrogram save failed: {e}", flush=True)
def _save_wav(path, wav_tensor, sample_rate): def _save_wav(path, wav_tensor, sample_rate):
"""Save [channels, samples] float32 tensor to .wav. """Save [channels, samples] float32 tensor to .wav.
@@ -288,8 +319,13 @@ def _do_train(vocoder, mel_converter, clips,
wav = wav.unsqueeze(1) wav = wav.unsqueeze(1)
wav = wav.float().cpu().clamp(-1, 1) wav = wav.float().cpu().clamp(-1, 1)
wav_path = out_path.parent / f"{out_path.stem}_{label}.wav" wav_path = out_path.parent / f"{out_path.stem}_{label}.wav"
spec_path = out_path.parent / f"{out_path.stem}_{label}_spec.png"
_save_wav(wav_path, wav.squeeze(0), sample_rate) _save_wav(wav_path, wav.squeeze(0), sample_rate)
print(f"[BigVGAN] Sample saved: {wav_path}", flush=True) # Compute mel of the vocoded output for visual comparison
with torch.no_grad():
pred_mel = mel_converter(wav.squeeze(1).to(mel_converter.mel_basis.device))
_save_spectrogram(spec_path, pred_mel)
print(f"[BigVGAN] Sample saved: {wav_path} spec: {spec_path}", flush=True)
except Exception as e: except Exception as e:
print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True) print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True)