feat: save eval audio sample alongside each checkpoint
Every save_every steps, run a quick 8-step no-CFG inference pass on a random training clip and save the decoded waveform as sample_stepXXXXX.wav next to the checkpoint. This uses the existing generator.unnormalize + feature_utils.decode + vocode pipeline from the sampler. A failure is non-fatal: it is logged and the sample is skipped. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -74,6 +74,65 @@ def _load_npz(path: Path) -> dict:
|
|||||||
return bundle
|
return bundle
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Eval sample
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
                 num_steps: int = 8):
    """Run a quick no-CFG inference pass on a random training clip.

    Uses fewer ODE steps than inference (8 vs 25) for speed. Any failure is
    logged and swallowed so that a broken eval sample never aborts training.
    The generator's train/eval mode is restored to whatever it was on entry.

    Args:
        generator: Model exposing ``forward``, ``unnormalize`` and
            ``latent_dim``.
        feature_utils_orig: Module exposing ``decode`` and ``vocode``. It may
            live on another device (offload strategy); it is moved to
            ``device`` for decoding and moved back afterwards.
        dataset: Sequence of 4-tuples; the last three items are the clip, sync
            and text-clip feature tensors. The first item is ignored.
        seq_cfg: Config exposing ``latent_seq_len`` and ``sampling_rate``.
        device: Device the inference pass runs on.
        dtype: dtype used for the features and the initial noise.
        num_steps: Number of Euler steps for the flow-matching solver.

    Returns:
        (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
    """
    # Remember the caller's mode so the ``finally`` below restores it instead
    # of forcing train mode. Falls back to train mode (the previous behavior)
    # if the object does not expose ``training``.
    prev_mode = getattr(generator, "training", True)
    generator.eval()
    try:
        _, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
        clip_f = clip_f_cpu.to(device, dtype)
        sync_f = sync_f_cpu.to(device, dtype)
        text_clip = text_clip_cpu.to(device, dtype)

        x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
                         device=device, dtype=dtype)

        eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)

        def velocity_fn(t, x):
            return generator.forward(x, clip_f, sync_f, text_clip,
                                     t.reshape(1).to(device, dtype))

        with torch.no_grad():
            x1_pred = eval_fm.to_data(velocity_fn, x0)
            x1_unnorm = generator.unnormalize(x1_pred)

            # feature_utils_orig may be on CPU (offload strategy) — move temporarily
            orig_device = next(feature_utils_orig.parameters()).device
            if orig_device != device:
                feature_utils_orig.to(device)
            try:
                spec = feature_utils_orig.decode(x1_unnorm)
                audio = feature_utils_orig.vocode(spec)
            finally:
                # Always undo the temporary move, even if decoding failed.
                if orig_device != device:
                    feature_utils_orig.to(orig_device)

        audio = audio.float().cpu()
        # Bring the vocoder output to [B, 1, L] whatever rank it came back in.
        if audio.dim() == 1:
            audio = audio.reshape(1, 1, -1)
        elif audio.dim() == 2:
            audio = audio.unsqueeze(1)
        elif audio.dim() == 3 and audio.shape[1] != 1:
            # Down-mix multi-channel output to mono.
            audio = audio.mean(dim=1, keepdim=True)

        # Peak-normalize; the clamp on ``peak`` avoids dividing by zero on silence.
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)
        # A single clip is generated, so take the first batch item -> [1, L].
        return audio[0], seq_cfg.sampling_rate

    except Exception as e:
        # Deliberately best-effort: log and let training continue.
        print(f"[LoRA Trainer] Eval sample failed: {e}", flush=True)
        return None, None
    finally:
        generator.train(prev_mode)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Loss curve rendering
|
# Loss curve rendering
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -406,6 +465,14 @@ class SelvaLoraTrainer:
|
|||||||
}, ckpt_path)
|
}, ckpt_path)
|
||||||
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
|
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
|
||||||
|
|
||||||
|
# Save a quick eval sample next to the checkpoint
|
||||||
|
wav, sr = _eval_sample(generator, feature_utils_orig,
|
||||||
|
dataset, seq_cfg, device, dtype)
|
||||||
|
if wav is not None:
|
||||||
|
wav_path = output_dir / f"sample_step{step:05d}.wav"
|
||||||
|
torchaudio.save(str(wav_path), wav, sr)
|
||||||
|
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||||||
|
|
||||||
pbar_train.update(1)
|
pbar_train.update(1)
|
||||||
|
|
||||||
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
|
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
|
||||||
|
|||||||
Reference in New Issue
Block a user