feat(bigvgan-trainer): add eval samples at checkpoints and end
Saves baseline.wav (ground truth roundtrip before training), stepN.wav at each save_every checkpoint, and final.wav after training completes. All use the same fixed reference segment (clip 0, position 0) for direct comparison across checkpoints. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -175,6 +175,27 @@ class SelvaBigvganTrainer:
|
|||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
|
# Fixed reference segment for eval samples — always clip 0, start 0
|
||||||
|
ref_clip = clips[0][:segment_samples].to(device) # [T]
|
||||||
|
ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||||
|
|
||||||
|
def _save_sample(label):
|
||||||
|
"""Vocode the reference mel and save as .wav."""
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
wav = vocoder(ref_mel) # [1, 1, T] or [1, T]
|
||||||
|
if wav.dim() == 2:
|
||||||
|
wav = wav.unsqueeze(1)
|
||||||
|
wav = wav.float().cpu().clamp(-1, 1)
|
||||||
|
wav_path = out_path.parent / f"{out_path.stem}_{label}.wav"
|
||||||
|
torchaudio.save(str(wav_path), wav.squeeze(0), sample_rate)
|
||||||
|
print(f"[BigVGAN] Sample saved: {wav_path}", flush=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True)
|
||||||
|
|
||||||
|
# Baseline: ground truth roundtrip before any fine-tuning
|
||||||
|
_save_sample("baseline")
|
||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -230,6 +251,9 @@ class SelvaBigvganTrainer:
|
|||||||
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
|
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
|
||||||
torch.save({"generator": vocoder.state_dict()}, str(step_path))
|
torch.save({"generator": vocoder.state_dict()}, str(step_path))
|
||||||
print(f"[BigVGAN] Checkpoint: {step_path}", flush=True)
|
print(f"[BigVGAN] Checkpoint: {step_path}", flush=True)
|
||||||
|
vocoder.eval()
|
||||||
|
_save_sample(f"step{step+1}")
|
||||||
|
vocoder.train()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
vocoder.requires_grad_(False)
|
vocoder.requires_grad_(False)
|
||||||
@@ -240,4 +264,5 @@ class SelvaBigvganTrainer:
|
|||||||
|
|
||||||
torch.save({"generator": vocoder.state_dict()}, str(out_path))
|
torch.save({"generator": vocoder.state_dict()}, str(out_path))
|
||||||
print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
|
print(f"\n[BigVGAN] Saved: {out_path}", flush=True)
|
||||||
|
_save_sample("final")
|
||||||
return (str(out_path),)
|
return (str(out_path),)
|
||||||
|
|||||||
Reference in New Issue
Block a user