From 0b24207ca5d6581fa15f5825ac6113c0e745b273 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Wed, 8 Apr 2026 23:33:28 +0200 Subject: [PATCH] feat(ti-trainer): generate baseline.wav once before training starts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Saves baseline.wav + baseline.png in the checkpoint dir using the same seed as the TI eval samples — direct A/B comparison at every checkpoint without re-generating the baseline each time. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_textual_inversion_trainer.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/nodes/selva_textual_inversion_trainer.py b/nodes/selva_textual_inversion_trainer.py index 8c2e7d7..92cb5eb 100644 --- a/nodes/selva_textual_inversion_trainer.py +++ b/nodes/selva_textual_inversion_trainer.py @@ -30,6 +30,7 @@ from .utils import SELVA_CATEGORY, get_device, soft_empty_cache from selva_core.model.flow_matching import FlowMatching from .selva_lora_trainer import ( _prepare_dataset, + _eval_sample, _spectral_metrics, _save_spectrogram, _smooth_losses, @@ -300,6 +301,24 @@ class SelvaTextualInversionTrainer: ckpt_dir = out_path.parent / out_path.stem ckpt_dir.mkdir(parents=True, exist_ok=True) + # --- Baseline sample (once, before any training) --- + print(f"[TI Trainer] Generating baseline sample...", flush=True) + baseline_wav, baseline_sr = _eval_sample( + generator, feature_utils_orig, dataset, seq_cfg, device, dtype, seed=seed, + ) + if baseline_wav is not None: + baseline_path = ckpt_dir / "baseline.wav" + try: + torchaudio.save(str(baseline_path), baseline_wav, baseline_sr) + except RuntimeError: + import soundfile as sf + sf.write(str(baseline_path), baseline_wav.squeeze(0).numpy(), baseline_sr) + try: + _save_spectrogram(baseline_wav, baseline_sr, ckpt_dir / "baseline.png") + except Exception: + pass + print(f"[TI Trainer] Baseline saved: {baseline_path}", flush=True) + # --- Training loop --- generator.train() optimizer.zero_grad()