# Patch for nodes/selva_bigvgan_trainer.py (class SelvaBigvganTrainer): write audible
# evaluation samples next to the checkpoints while the vocoder is being trained.
#
# Hunk 1 (inserted after the seeding calls, before the progress bar and the `try:`):
#   * ref_clip = clips[0][:segment_samples] moved to `device`; ref_mel is that clip
#     passed through mel_converter with a leading dim added via unsqueeze(0).
#     The same reference is reused for every sample.
#   * _save_sample(label): nested helper that closes over vocoder, ref_mel, out_path
#     and sample_rate. It runs vocoder(ref_mel) under torch.no_grad(), inserts a dim
#     at index 1 when the output is 2-D, converts to float on the CPU, clamps to
#     [-1, 1], and writes "<out_path.stem>_<label>.wav" into out_path.parent with
#     torchaudio.save (after squeeze(0)). Any Exception is caught and printed, so a
#     failed sample does not propagate out of the helper.
#   * _save_sample("baseline") is called once before the `try:` block.
# Hunk 2 (after the per-step checkpoint is saved and logged):
#   * vocoder.eval(); _save_sample(f"step{step+1}"); vocoder.train()
# Hunk 3 (after the final state dict is saved and logged, before the return):
#   * _save_sample("final")
#
# NOTE(review): only the per-step call wraps the helper in eval()/train(); the
#   "baseline" and "final" calls use whatever mode the vocoder is in at that point,
#   which is not visible in this patch -- confirm, or move the mode handling into
#   the helper.
# NOTE(review): the per-step call unconditionally ends with vocoder.train();
#   presumably the loop always runs in train mode -- verify against the full file.
# NOTE(review): clips[0] is indexed outside any try block; presumably `clips` is
#   guaranteed non-empty earlier in the method -- TODO confirm.
# NOTE(review): as seen here the whole diff sits on one line; if the file itself is
#   stored this way, the hunk line breaks and indentation must be restored before
#   it can be applied.
diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 095ef76..5ce63cc 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -175,6 +175,27 @@ class SelvaBigvganTrainer: torch.manual_seed(seed) random.seed(seed) + # Fixed reference segment for eval samples — always clip 0, start 0 + ref_clip = clips[0][:segment_samples].to(device) # [T] + ref_mel = mel_converter(ref_clip.unsqueeze(0)) # [1, n_mels, T_mel] + + def _save_sample(label): + """Vocode the reference mel and save as .wav.""" + try: + with torch.no_grad(): + wav = vocoder(ref_mel) # [1, 1, T] or [1, T] + if wav.dim() == 2: + wav = wav.unsqueeze(1) + wav = wav.float().cpu().clamp(-1, 1) + wav_path = out_path.parent / f"{out_path.stem}_{label}.wav" + torchaudio.save(str(wav_path), wav.squeeze(0), sample_rate) + print(f"[BigVGAN] Sample saved: {wav_path}", flush=True) + except Exception as e: + print(f"[BigVGAN] Sample save failed ({label}): {e}", flush=True) + + # Baseline: ground truth roundtrip before any fine-tuning + _save_sample("baseline") + pbar = comfy.utils.ProgressBar(steps) try: @@ -230,6 +251,9 @@ class SelvaBigvganTrainer: step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}" torch.save({"generator": vocoder.state_dict()}, str(step_path)) print(f"[BigVGAN] Checkpoint: {step_path}", flush=True) + vocoder.eval() + _save_sample(f"step{step+1}") + vocoder.train() finally: vocoder.requires_grad_(False) @@ -240,4 +264,5 @@ class SelvaBigvganTrainer: torch.save({"generator": vocoder.state_dict()}, str(out_path)) print(f"\n[BigVGAN] Saved: {out_path}", flush=True) + _save_sample("final") return (str(out_path),)