feat: save eval audio sample alongside each checkpoint

At every save_every steps, run a quick 8-step no-CFG inference pass on
a random training clip and save the decoded waveform as
sample_stepXXXXX.wav next to the checkpoint. Uses the existing
generator.unnormalize + feature_utils.decode + vocode pipeline from
the sampler. Failure is non-fatal (logged and skipped).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 21:47:02 +02:00
parent b430953602
commit 56c8d5d6b4
+67
View File
@@ -74,6 +74,65 @@ def _load_npz(path: Path) -> dict:
return bundle
# ---------------------------------------------------------------------------
# Eval sample
# ---------------------------------------------------------------------------
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
num_steps: int = 8):
"""Run a quick no-CFG inference pass on a random training clip.
Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
Uses fewer ODE steps than inference (8 vs 25) for speed.
"""
generator.eval()
try:
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype)
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
device=device, dtype=dtype)
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
def velocity_fn(t, x):
return generator.forward(x, clip_f, sync_f, text_clip,
t.reshape(1).to(device, dtype))
with torch.no_grad():
x1_pred = eval_fm.to_data(velocity_fn, x0)
x1_unnorm = generator.unnormalize(x1_pred)
# feature_utils_orig may be on CPU (offload strategy) — move temporarily
orig_device = next(feature_utils_orig.parameters()).device
if orig_device != device:
feature_utils_orig.to(device)
try:
spec = feature_utils_orig.decode(x1_unnorm)
audio = feature_utils_orig.vocode(spec)
finally:
if orig_device != device:
feature_utils_orig.to(orig_device)
audio = audio.float().cpu()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True)
peak = audio.abs().max().clamp(min=1e-8)
audio = (audio / peak).clamp(-1, 1)
return audio.squeeze(0), seq_cfg.sampling_rate # [1, L]
except Exception as e:
print(f"[LoRA Trainer] Eval sample failed: {e}", flush=True)
return None, None
finally:
generator.train()
# ---------------------------------------------------------------------------
# Loss curve rendering
# ---------------------------------------------------------------------------
@@ -406,6 +465,14 @@ class SelvaLoraTrainer:
}, ckpt_path)
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
# Save a quick eval sample next to the checkpoint
wav, sr = _eval_sample(generator, feature_utils_orig,
dataset, seq_cfg, device, dtype)
if wav is not None:
wav_path = output_dir / f"sample_step{step:05d}.wav"
torchaudio.save(str(wav_path), wav, sr)
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
pbar_train.update(1)
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)