feat(bigvgan-trainer): add eval samples at checkpoints and end

Saves <stem>_baseline.wav (ground truth roundtrip before training),
<stem>_stepN.wav at each save_every checkpoint, and <stem>_final.wav after
training completes, where <stem> is the stem of the output checkpoint path.
All use the same fixed reference segment (clip 0, position 0) for
direct comparison across checkpoints.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 01:30:34 +02:00
parent 790a53e3df
commit 9fdeb65182
+25
View File
@@ -175,6 +175,27 @@ class SelvaBigvganTrainer:
torch.manual_seed(seed) torch.manual_seed(seed)
random.seed(seed) random.seed(seed)
# Fixed reference segment for eval samples — always clip 0, start 0
ref_clip = clips[0][:segment_samples].to(device) # [T]
ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel]
def _save_sample(label):
    """Vocode the fixed reference mel and write the result as a .wav file.

    The file is written next to ``out_path`` as
    ``<out_path.stem>_<label>.wav``. Saving is best-effort: any failure is
    reported on stdout and swallowed so that a broken sample never aborts
    the training run.

    Args:
        label: Suffix identifying the sample, e.g. ``"baseline"``,
            ``"step500"`` or ``"final"``.
    """
    # Render every sample in eval mode and restore the previous mode
    # afterwards, so the output does not depend on whether the call site
    # remembered to toggle eval()/train() around this helper.
    was_training = vocoder.training
    try:
        vocoder.eval()
        with torch.no_grad():
            wav = vocoder(ref_mel)  # [1, 1, T] or [1, T]
        if wav.dim() == 2:
            # Insert a channel axis so both output layouts end up 3-D.
            wav = wav.unsqueeze(1)
        wav = wav.float().cpu().clamp(-1, 1)
        wav_path = out_path.parent / f"{out_path.stem}_{label}.wav"
        # Drop the leading batch axis before handing the tensor to torchaudio.
        torchaudio.save(str(wav_path), wav.squeeze(0), sample_rate)
        print(f"[BigVGAN] Sample saved: {wav_path}", flush=True)
    except Exception as e:
        # Deliberate best-effort: log and continue training.
        print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True)
    finally:
        # Put the vocoder back in whatever mode the caller had it in.
        vocoder.train(was_training)
# Baseline: ground truth roundtrip before any fine-tuning
_save_sample("baseline")
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
try: try:
@@ -230,6 +251,9 @@ class SelvaBigvganTrainer:
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}" step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
torch.save({"generator": vocoder.state_dict()}, str(step_path)) torch.save({"generator": vocoder.state_dict()}, str(step_path))
print(f"[BigVGAN] Checkpoint: {step_path}", flush=True) print(f"[BigVGAN] Checkpoint: {step_path}", flush=True)
vocoder.eval()
_save_sample(f"step{step+1}")
vocoder.train()
finally: finally:
vocoder.requires_grad_(False) vocoder.requires_grad_(False)
@@ -240,4 +264,5 @@ class SelvaBigvganTrainer:
torch.save({"generator": vocoder.state_dict()}, str(out_path)) torch.save({"generator": vocoder.state_dict()}, str(out_path))
print(f"\n[BigVGAN] Saved: {out_path}", flush=True) print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
_save_sample("final")
return (str(out_path),) return (str(out_path),)