feat(ti-trainer): generate baseline.wav once before training starts
Saves baseline.wav + baseline.png in the checkpoint dir using the same seed as the TI eval samples — direct A/B comparison at every checkpoint without re-generating the baseline each time. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -30,6 +30,7 @@ from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
from .selva_lora_trainer import (
|
||||
_prepare_dataset,
|
||||
_eval_sample,
|
||||
_spectral_metrics,
|
||||
_save_spectrogram,
|
||||
_smooth_losses,
|
||||
@@ -300,6 +301,24 @@ class SelvaTextualInversionTrainer:
|
||||
ckpt_dir = out_path.parent / out_path.stem
|
||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# --- Baseline sample (once, before any training) ---
|
||||
print(f"[TI Trainer] Generating baseline sample...", flush=True)
|
||||
baseline_wav, baseline_sr = _eval_sample(
|
||||
generator, feature_utils_orig, dataset, seq_cfg, device, dtype, seed=seed,
|
||||
)
|
||||
if baseline_wav is not None:
|
||||
baseline_path = ckpt_dir / "baseline.wav"
|
||||
try:
|
||||
torchaudio.save(str(baseline_path), baseline_wav, baseline_sr)
|
||||
except RuntimeError:
|
||||
import soundfile as sf
|
||||
sf.write(str(baseline_path), baseline_wav.squeeze(0).numpy(), baseline_sr)
|
||||
try:
|
||||
_save_spectrogram(baseline_wav, baseline_sr, ckpt_dir / "baseline.png")
|
||||
except Exception:
|
||||
pass
|
||||
print(f"[TI Trainer] Baseline saved: {baseline_path}", flush=True)
|
||||
|
||||
# --- Training loop ---
|
||||
generator.train()
|
||||
optimizer.zero_grad()
|
||||
|
||||
Reference in New Issue
Block a user