feat: save eval audio sample alongside each checkpoint
At every save_every steps, run a quick 8-step no-CFG inference pass on a random training clip and save the decoded waveform as sample_stepXXXXX.wav next to the checkpoint. Uses the existing generator.unnormalize + feature_utils.decode + vocode pipeline from the sampler. Failure is non-fatal (logged and skipped). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -74,6 +74,65 @@ def _load_npz(path: Path) -> dict:
|
||||
return bundle
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Eval sample
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
|
||||
num_steps: int = 8):
|
||||
"""Run a quick no-CFG inference pass on a random training clip.
|
||||
|
||||
Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
|
||||
Uses fewer ODE steps than inference (8 vs 25) for speed.
|
||||
"""
|
||||
generator.eval()
|
||||
try:
|
||||
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
||||
clip_f = clip_f_cpu.to(device, dtype)
|
||||
sync_f = sync_f_cpu.to(device, dtype)
|
||||
text_clip = text_clip_cpu.to(device, dtype)
|
||||
|
||||
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
|
||||
device=device, dtype=dtype)
|
||||
|
||||
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
||||
|
||||
def velocity_fn(t, x):
|
||||
return generator.forward(x, clip_f, sync_f, text_clip,
|
||||
t.reshape(1).to(device, dtype))
|
||||
|
||||
with torch.no_grad():
|
||||
x1_pred = eval_fm.to_data(velocity_fn, x0)
|
||||
x1_unnorm = generator.unnormalize(x1_pred)
|
||||
|
||||
# feature_utils_orig may be on CPU (offload strategy) — move temporarily
|
||||
orig_device = next(feature_utils_orig.parameters()).device
|
||||
if orig_device != device:
|
||||
feature_utils_orig.to(device)
|
||||
try:
|
||||
spec = feature_utils_orig.decode(x1_unnorm)
|
||||
audio = feature_utils_orig.vocode(spec)
|
||||
finally:
|
||||
if orig_device != device:
|
||||
feature_utils_orig.to(orig_device)
|
||||
|
||||
audio = audio.float().cpu()
|
||||
if audio.dim() == 2:
|
||||
audio = audio.unsqueeze(1)
|
||||
elif audio.dim() == 3 and audio.shape[1] != 1:
|
||||
audio = audio.mean(dim=1, keepdim=True)
|
||||
|
||||
peak = audio.abs().max().clamp(min=1e-8)
|
||||
audio = (audio / peak).clamp(-1, 1)
|
||||
return audio.squeeze(0), seq_cfg.sampling_rate # [1, L]
|
||||
|
||||
except Exception as e:
|
||||
print(f"[LoRA Trainer] Eval sample failed: {e}", flush=True)
|
||||
return None, None
|
||||
finally:
|
||||
generator.train()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Loss curve rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -406,6 +465,14 @@ class SelvaLoraTrainer:
|
||||
}, ckpt_path)
|
||||
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
|
||||
|
||||
# Save a quick eval sample next to the checkpoint
|
||||
wav, sr = _eval_sample(generator, feature_utils_orig,
|
||||
dataset, seq_cfg, device, dtype)
|
||||
if wav is not None:
|
||||
wav_path = output_dir / f"sample_step{step:05d}.wav"
|
||||
torchaudio.save(str(wav_path), wav, sr)
|
||||
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||||
|
||||
pbar_train.update(1)
|
||||
|
||||
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
|
||||
|
||||
Reference in New Issue
Block a user