From 56c8d5d6b45eb6cc2054207b3e5394cfc21f16c9 Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Sun, 5 Apr 2026 21:47:02 +0200
Subject: [PATCH] feat: save eval audio sample alongside each checkpoint

At every save_every steps, run a quick 8-step no-CFG inference pass on a
random training clip and save the decoded waveform as sample_stepXXXXX.wav
next to the checkpoint. Uses the existing generator.unnormalize +
feature_utils.decode + vocode pipeline from the sampler.

Failure is non-fatal (logged and skipped).

Co-Authored-By: Claude Sonnet 4.6
---
NOTE(review): layout restored from a whitespace-collapsed copy of this patch.
Indentation of context and added lines is reconstructed -- verify against
nodes/selva_lora_trainer.py before applying. The post-image blob id on the
"index" line is carried over from the original patch and is stale.

 nodes/selva_lora_trainer.py | 79 +++++++++++++++++++++++++++++++++++++
 1 file changed, 79 insertions(+)

diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py
index 866fc3b..2fa6d78 100644
--- a/nodes/selva_lora_trainer.py
+++ b/nodes/selva_lora_trainer.py
@@ -74,6 +74,72 @@ def _load_npz(path: Path) -> dict:
     return bundle
 
 
+# ---------------------------------------------------------------------------
+# Eval sample
+# ---------------------------------------------------------------------------
+
+def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
+                 num_steps: int = 8):
+    """Run a quick no-CFG inference pass on a random training clip.
+
+    Picks one item from ``dataset`` at random, integrates from Gaussian noise
+    with ``num_steps`` Euler steps, then decodes and vocodes the result via
+    ``feature_utils_orig``. The generator is put in eval mode for the pass
+    and restored to its previous mode afterwards.
+
+    Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
+    Uses fewer ODE steps than inference (8 vs 25) for speed.
+    """
+    # Remember the current mode so the caller's train/eval state is preserved.
+    was_training = getattr(generator, "training", True)
+    generator.eval()
+    try:
+        _, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
+        clip_f = clip_f_cpu.to(device, dtype)
+        sync_f = sync_f_cpu.to(device, dtype)
+        text_clip = text_clip_cpu.to(device, dtype)
+
+        x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
+                         device=device, dtype=dtype)
+
+        eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
+
+        def velocity_fn(t, x):
+            return generator.forward(x, clip_f, sync_f, text_clip,
+                                     t.reshape(1).to(device, dtype))
+
+        with torch.no_grad():
+            x1_pred = eval_fm.to_data(velocity_fn, x0)
+            x1_unnorm = generator.unnormalize(x1_pred)
+
+            # feature_utils_orig may be on CPU (offload strategy) — move temporarily
+            orig_device = next(feature_utils_orig.parameters()).device
+            if orig_device != device:
+                feature_utils_orig.to(device)
+            try:
+                spec = feature_utils_orig.decode(x1_unnorm)
+                audio = feature_utils_orig.vocode(spec)
+            finally:
+                if orig_device != device:
+                    feature_utils_orig.to(orig_device)
+
+        audio = audio.float().cpu()
+        if audio.dim() == 2:
+            audio = audio.unsqueeze(1)
+        elif audio.dim() == 3 and audio.shape[1] != 1:
+            audio = audio.mean(dim=1, keepdim=True)
+
+        peak = audio.abs().max().clamp(min=1e-8)
+        audio = (audio / peak).clamp(-1, 1)
+        return audio.squeeze(0), seq_cfg.sampling_rate  # [1, L]
+
+    except Exception as e:
+        print(f"[LoRA Trainer] Eval sample failed: {e}", flush=True)
+        return None, None
+    finally:
+        generator.train(was_training)
+
+
 # ---------------------------------------------------------------------------
 # Loss curve rendering
 # ---------------------------------------------------------------------------
@@ -406,6 +472,19 @@ class SelvaLoraTrainer:
                 }, ckpt_path)
                 print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
 
+                # Save a quick eval sample next to the checkpoint
+                wav, sr = _eval_sample(generator, feature_utils_orig,
+                                       dataset, seq_cfg, device, dtype)
+                if wav is not None:
+                    wav_path = output_dir / f"sample_step{step:05d}.wav"
+                    # Writing the sample is best-effort: a failed write must
+                    # not abort training.
+                    try:
+                        torchaudio.save(str(wav_path), wav, sr)
+                        print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
+                    except Exception as e:
+                        print(f"[LoRA Trainer] Sample save failed: {e}", flush=True)
+
             pbar_train.update(1)
 
         # Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)