fix: eval samples use fixed clip/seed, save to samples/ subfolder

- Always sample dataset[0] with a fixed noise seed so samples saved at different
  checkpoints are directly comparable (you can hear the model improve step by step)
- Save to output_dir/samples/step_XXXXX.wav instead of alongside checkpoints

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 00:54:37 +02:00
parent 0682a536cb
commit f15e02b0b8
+12 -7
View File
@@ -88,21 +88,24 @@ def _load_npz(path: Path) -> dict:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype, def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
num_steps: int = 8): num_steps: int = 8, seed: int = 0):
"""Run a quick no-CFG inference pass on a random training clip. """Run a quick no-CFG inference pass on a fixed training clip.
Always uses dataset[0] and a fixed noise seed so samples across checkpoints
are directly comparable — you can hear the model improve step by step.
Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure. Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
Uses fewer ODE steps than inference (8 vs 25) for speed. Uses fewer ODE steps than inference (8 vs 25) for speed.
""" """
generator.eval() generator.eval()
try: try:
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) _, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[0]
clip_f = clip_f_cpu.to(device, dtype) clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype) sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype) text_clip = text_clip_cpu.to(device, dtype)
rng = torch.Generator(device=device).manual_seed(seed)
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim, x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
device=device, dtype=dtype) device=device, dtype=dtype, generator=rng)
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps) eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
@@ -664,11 +667,13 @@ class SelvaLoraTrainer:
}, ckpt_path) }, ckpt_path)
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True) print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
# Save a quick eval sample next to the checkpoint # Save a quick eval sample in samples/ subfolder
samples_dir = output_dir / "samples"
samples_dir.mkdir(exist_ok=True)
wav, sr = _eval_sample(generator, feature_utils_orig, wav, sr = _eval_sample(generator, feature_utils_orig,
dataset, seq_cfg, device, dtype) dataset, seq_cfg, device, dtype, seed=seed)
if wav is not None: if wav is not None:
wav_path = output_dir / f"sample_step{step:05d}.wav" wav_path = samples_dir / f"step_{step:05d}.wav"
try: try:
torchaudio.save(str(wav_path), wav, sr) torchaudio.save(str(wav_path), wav, sr)
except RuntimeError: except RuntimeError: