diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index d1b9911..f88820b 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -88,21 +88,24 @@ def _load_npz(path: Path) -> dict: # --------------------------------------------------------------------------- def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype, - num_steps: int = 8): - """Run a quick no-CFG inference pass on a random training clip. + num_steps: int = 8, seed: int = 0): + """Run a quick no-CFG inference pass on a fixed training clip. + Always uses dataset[0] and a fixed noise seed so samples across checkpoints + are directly comparable — you can hear the model improve step by step. Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure. Uses fewer ODE steps than inference (8 vs 25) for speed. """ generator.eval() try: - _, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) + _, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[0] clip_f = clip_f_cpu.to(device, dtype) sync_f = sync_f_cpu.to(device, dtype) text_clip = text_clip_cpu.to(device, dtype) + rng = torch.Generator(device=device).manual_seed(seed) x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim, - device=device, dtype=dtype) + device=device, dtype=dtype, generator=rng) eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps) @@ -664,11 +667,13 @@ class SelvaLoraTrainer: }, ckpt_path) print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True) - # Save a quick eval sample next to the checkpoint + # Save a quick eval sample in samples/ subfolder + samples_dir = output_dir / "samples" + samples_dir.mkdir(exist_ok=True) wav, sr = _eval_sample(generator, feature_utils_orig, - dataset, seq_cfg, device, dtype) + dataset, seq_cfg, device, dtype, seed=seed) if wav is not None: - wav_path = output_dir / f"sample_step{step:05d}.wav" + wav_path = samples_dir / f"step_{step:05d}.wav" try: torchaudio.save(str(wav_path), wav, sr) except RuntimeError: