feat(ti-trainer): generate baseline.wav once before training starts
Saves baseline.wav + baseline.png in the checkpoint dir using the same seed as the TI eval samples — direct A/B comparison at every checkpoint without re-generating the baseline each time. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -30,6 +30,7 @@ from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
|||||||
from selva_core.model.flow_matching import FlowMatching
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
from .selva_lora_trainer import (
|
from .selva_lora_trainer import (
|
||||||
_prepare_dataset,
|
_prepare_dataset,
|
||||||
|
_eval_sample,
|
||||||
_spectral_metrics,
|
_spectral_metrics,
|
||||||
_save_spectrogram,
|
_save_spectrogram,
|
||||||
_smooth_losses,
|
_smooth_losses,
|
||||||
@@ -300,6 +301,24 @@ class SelvaTextualInversionTrainer:
|
|||||||
ckpt_dir = out_path.parent / out_path.stem
|
ckpt_dir = out_path.parent / out_path.stem
|
||||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# --- Baseline sample (once, before any training) ---
|
||||||
|
print(f"[TI Trainer] Generating baseline sample...", flush=True)
|
||||||
|
baseline_wav, baseline_sr = _eval_sample(
|
||||||
|
generator, feature_utils_orig, dataset, seq_cfg, device, dtype, seed=seed,
|
||||||
|
)
|
||||||
|
if baseline_wav is not None:
|
||||||
|
baseline_path = ckpt_dir / "baseline.wav"
|
||||||
|
try:
|
||||||
|
torchaudio.save(str(baseline_path), baseline_wav, baseline_sr)
|
||||||
|
except RuntimeError:
|
||||||
|
import soundfile as sf
|
||||||
|
sf.write(str(baseline_path), baseline_wav.squeeze(0).numpy(), baseline_sr)
|
||||||
|
try:
|
||||||
|
_save_spectrogram(baseline_wav, baseline_sr, ckpt_dir / "baseline.png")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
print(f"[TI Trainer] Baseline saved: {baseline_path}", flush=True)
|
||||||
|
|
||||||
# --- Training loop ---
|
# --- Training loop ---
|
||||||
generator.train()
|
generator.train()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|||||||
Reference in New Issue
Block a user