feat(ti-trainer): generate baseline.wav once before training starts

Saves baseline.wav and baseline.png in the checkpoint dir, generated with
the same seed as the TI eval samples. This allows a direct A/B comparison
at every checkpoint without re-generating the baseline each time.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 23:33:28 +02:00
parent e1a2f0ed7d
commit 0b24207ca5
+19
View File
@@ -30,6 +30,7 @@ from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from selva_core.model.flow_matching import FlowMatching
from .selva_lora_trainer import (
_prepare_dataset,
_eval_sample,
_spectral_metrics,
_save_spectrogram,
_smooth_losses,
@@ -300,6 +301,24 @@ class SelvaTextualInversionTrainer:
ckpt_dir = out_path.parent / out_path.stem
ckpt_dir.mkdir(parents=True, exist_ok=True)
# --- Baseline sample (once, before any training) ---
print(f"[TI Trainer] Generating baseline sample...", flush=True)
baseline_wav, baseline_sr = _eval_sample(
generator, feature_utils_orig, dataset, seq_cfg, device, dtype, seed=seed,
)
if baseline_wav is not None:
baseline_path = ckpt_dir / "baseline.wav"
try:
torchaudio.save(str(baseline_path), baseline_wav, baseline_sr)
except RuntimeError:
import soundfile as sf
sf.write(str(baseline_path), baseline_wav.squeeze(0).numpy(), baseline_sr)
try:
_save_spectrogram(baseline_wav, baseline_sr, ckpt_dir / "baseline.png")
except Exception:
pass
print(f"[TI Trainer] Baseline saved: {baseline_path}", flush=True)
# --- Training loop ---
generator.train()
optimizer.zero_grad()