feat: save mel spectrogram PNG alongside each eval sample
Adds _save_spectrogram() using PIL only (no matplotlib). Each _save_sample call now writes both a .wav and a _spec.png so training progress is visible without listening. Colour map is blue→green→yellow (viridis-ish), low frequencies at the bottom. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -23,6 +23,37 @@ import folder_paths
|
||||
|
||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||
|
||||
def _save_spectrogram(path, mel_tensor):
|
||||
"""Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep).
|
||||
|
||||
Normalises to [0, 255], flips frequency axis so low freqs are at the bottom,
|
||||
and saves as a greyscale PNG with a simple viridis-like colour map.
|
||||
"""
|
||||
try:
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
mel = mel_tensor.squeeze(0).float().cpu().numpy() # [n_mels, T]
|
||||
mel = mel[::-1] # low freq at bottom
|
||||
lo, hi = mel.min(), mel.max()
|
||||
if hi > lo:
|
||||
mel = (mel - lo) / (hi - lo)
|
||||
else:
|
||||
mel = mel - lo
|
||||
img_u8 = (mel * 255).clip(0, 255).astype(np.uint8)
|
||||
|
||||
# Simple blue→green→yellow colour map (viridis-ish) via LUT
|
||||
lut_r = np.array([int(max(0, min(255, 255 * (v * 2 - 1)))) for v in np.linspace(0, 1, 256)], dtype=np.uint8)
|
||||
lut_g = np.array([int(max(0, min(255, 255 * (1 - abs(v * 2 - 1))))) for v in np.linspace(0, 1, 256)], dtype=np.uint8)
|
||||
lut_b = np.array([int(max(0, min(255, 255 * (1 - v * 2)))) for v in np.linspace(0, 1, 256)], dtype=np.uint8)
|
||||
r = Image.fromarray(lut_r[img_u8])
|
||||
g = Image.fromarray(lut_g[img_u8])
|
||||
b = Image.fromarray(lut_b[img_u8])
|
||||
Image.merge("RGB", (r, g, b)).save(str(path))
|
||||
except Exception as e:
|
||||
print(f"[BigVGAN] Spectrogram save failed: {e}", flush=True)
|
||||
|
||||
|
||||
def _save_wav(path, wav_tensor, sample_rate):
|
||||
"""Save [channels, samples] float32 tensor to .wav.
|
||||
|
||||
@@ -287,9 +318,14 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
if wav.dim() == 2:
|
||||
wav = wav.unsqueeze(1)
|
||||
wav = wav.float().cpu().clamp(-1, 1)
|
||||
wav_path = out_path.parent / f"{out_path.stem}_{label}.wav"
|
||||
wav_path = out_path.parent / f"{out_path.stem}_{label}.wav"
|
||||
spec_path = out_path.parent / f"{out_path.stem}_{label}_spec.png"
|
||||
_save_wav(wav_path, wav.squeeze(0), sample_rate)
|
||||
print(f"[BigVGAN] Sample saved: {wav_path}", flush=True)
|
||||
# Compute mel of the vocoded output for visual comparison
|
||||
with torch.no_grad():
|
||||
pred_mel = mel_converter(wav.squeeze(1).to(mel_converter.mel_basis.device))
|
||||
_save_spectrogram(spec_path, pred_mel)
|
||||
print(f"[BigVGAN] Sample saved: {wav_path} spec: {spec_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user