feat: evaluate adapters on all dataset clips, not just clip_001

- _eval_sample gains clip_idx param (default 0, backward compatible)
- Evaluator loops over all dataset clips per adapter, saves one WAV per clip
- Reference metrics computed for all clips and averaged
- Comparison chart and summary use avg_metrics across all clips
- Eliminates bias from evaluating on an unrepresentative single clip

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 17:42:55 +02:00
parent 42ceb4b153
commit fdce9cbbf1
2 changed files with 102 additions and 77 deletions
+75 -49
View File
@@ -44,6 +44,16 @@ from .selva_lora_trainer import (
from selva_core.model.lora import apply_lora, load_lora
def _avg_metrics(metrics_list: list) -> dict:
"""Average spectral metrics across multiple clips, ignoring None entries."""
keys = ["hf_energy_ratio", "spectral_centroid_hz", "spectral_rolloff_hz",
"spectral_flatness", "temporal_variance"]
valid = [m for m in metrics_list if m]
if not valid:
return {}
return {k: round(float(sum(m[k] for m in valid) / len(valid)), 4) for k in keys}
def _resolve_path(raw: str) -> Path:
p = Path(raw.strip())
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
@@ -200,28 +210,40 @@ class SelvaLoraEvaluator:
seq_cfg = model["seq_cfg"]
# ------------------------------------------------------------------
# 3. Load reference audio (first clip, same one used for eval samples)
# 3. Collect reference metrics for all dataset clips
# ------------------------------------------------------------------
first_npz = sorted(data_dir.glob("*.npz"))[0]
audio_path = _find_audio(first_npz)
ref_record = {"id": "reference", "path": str(audio_path) if audio_path else None,
"wav_path": None, "spectrogram_path": None,
"spectral_metrics": None, "status": "failed"}
if audio_path:
import shutil
npz_files = sorted(data_dir.glob("*.npz"))
ref_dir = output_dir / "reference"
ref_dir.mkdir(exist_ok=True)
ref_clips = [] # list of {clip, wav_path, spectral_metrics}
print(f"[LoRA Evaluator] Computing reference metrics for {len(npz_files)} clip(s)...",
flush=True)
for npz_path in npz_files:
audio_path = _find_audio(npz_path)
if audio_path is None:
continue
try:
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
ref_wav = ref_wav.unsqueeze(0) # [1, L]
import shutil
ref_out = output_dir / f"reference{audio_path.suffix}"
ref_out = ref_dir / f"{npz_path.stem}{audio_path.suffix}"
shutil.copy2(str(audio_path), str(ref_out))
ref_record["wav_path"] = str(ref_out)
ref_record["spectral_metrics"] = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
_save_spectrogram(ref_wav, seq_cfg.sampling_rate, output_dir / "reference")
ref_record["spectrogram_path"] = str((output_dir / "reference").with_suffix(".png"))
ref_record["status"] = "completed"
print(f"[LoRA Evaluator] Reference: {audio_path.name}", flush=True)
metrics = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
ref_clips.append({
"clip": npz_path.stem,
"wav_path": str(ref_out),
"spectral_metrics": metrics,
})
except Exception as e:
print(f"[LoRA Evaluator] Reference load failed: {e}", flush=True)
print(f"[LoRA Evaluator] Reference {npz_path.name} failed: {e}", flush=True)
# Average reference metrics across all clips
ref_avg = _avg_metrics([c["spectral_metrics"] for c in ref_clips])
print(f"[LoRA Evaluator] Reference avg — "
f"centroid={ref_avg.get('spectral_centroid_hz', 0):.0f}Hz "
f"hf={ref_avg.get('hf_energy_ratio', 0):.3f} "
f"flatness={ref_avg.get('spectral_flatness', 0):.4f}", flush=True)
# ------------------------------------------------------------------
# 4. Build summary skeleton
@@ -232,10 +254,12 @@ class SelvaLoraEvaluator:
"completed_at": None,
"data_dir": str(data_dir),
"output_dir": str(output_dir),
"reference": first_npz.name,
"n_clips": len(ref_clips),
"steps": steps,
"seed": seed,
"adapters": [ref_record],
"reference_avg": ref_avg,
"reference_clips": ref_clips,
"adapters": [],
}
summary_path = output_dir / "eval_summary.json"
@@ -245,33 +269,33 @@ class SelvaLoraEvaluator:
_write_summary()
# ------------------------------------------------------------------
# 5. Per-adapter evaluation loop
# 5. Per-adapter evaluation loop (all clips)
# ------------------------------------------------------------------
pbar = comfy.utils.ProgressBar(len(spec["adapters"]))
n_clips = len(dataset)
pbar = comfy.utils.ProgressBar(len(spec["adapters"]) * n_clips)
for adapter_spec in spec["adapters"]:
adapter_id = adapter_spec["id"]
adapter_path = (adapter_spec.get("path") or "").strip()
safe_id = _safe_stem(adapter_id)
clip_dir = output_dir / safe_id
clip_dir.mkdir(exist_ok=True)
record = {
"id": adapter_id,
"path": adapter_path or None,
"meta": None,
"wav_path": None,
"spectrogram_path": None,
"spectral_metrics": None,
"clips": [],
"avg_metrics": None,
"status": "running",
}
print(f"[LoRA Evaluator] ── '{adapter_id}' ──", flush=True)
print(f"\n[LoRA Evaluator] ── '{adapter_id}' ({n_clips} clips) ──", flush=True)
try:
with torch.inference_mode(False):
# 4a. Deep-copy generator
generator = copy.deepcopy(model["generator"])
# 4b. Apply + load LoRA if path given
if adapter_path:
pt_path = Path(adapter_path)
if not pt_path.is_absolute():
@@ -306,7 +330,6 @@ class SelvaLoraEvaluator:
else:
print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)
# 4c. Move to device and set sequence lengths
generator = generator.to(device, dtype)
generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
@@ -314,49 +337,53 @@ class SelvaLoraEvaluator:
sync_seq_len=seq_cfg.sync_seq_len,
)
# 4d. Run inference
clip_metrics_list = []
for clip_idx in range(n_clips):
clip_stem = npz_files[clip_idx].stem
wav, sr = _eval_sample(
generator, feature_utils_orig, dataset,
seq_cfg, device, dtype,
num_steps=steps, seed=seed,
num_steps=steps, seed=seed, clip_idx=clip_idx,
)
if wav is None:
raise RuntimeError("_eval_sample returned None")
pbar.update(1)
continue
# 4e. Save wav
wav_path = output_dir / f"{safe_id}.wav"
wav_path = clip_dir / f"{clip_stem}.wav"
try:
torchaudio.save(str(wav_path), wav, sr)
except RuntimeError:
import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
record["wav_path"] = str(wav_path)
print(f"[LoRA Evaluator] Saved {wav_path.name}", flush=True)
# 4f. Spectral metrics
metrics = _spectral_metrics(wav, sr)
record["spectral_metrics"] = metrics
print(f"[LoRA Evaluator] hf={metrics['hf_energy_ratio']:.3f} "
clip_metrics_list.append(metrics)
record["clips"].append({
"clip": clip_stem,
"wav_path": str(wav_path),
"spectral_metrics": metrics,
})
print(f" [{clip_idx+1}/{n_clips}] {clip_stem} "
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
f"flatness={metrics['spectral_flatness']:.3f} "
f"tv={metrics['temporal_variance']:.3f}", flush=True)
# 4g. Spectrogram PNG
spec_path = output_dir / safe_id
_save_spectrogram(wav, sr, spec_path)
record["spectrogram_path"] = str(spec_path.with_suffix(".png"))
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
pbar.update(1)
record["avg_metrics"] = _avg_metrics(clip_metrics_list)
record["status"] = "completed"
avg = record["avg_metrics"]
print(f"[LoRA Evaluator] '{adapter_id}' avg — "
f"centroid={avg.get('spectral_centroid_hz', 0):.0f}Hz "
f"hf={avg.get('hf_energy_ratio', 0):.3f} "
f"flatness={avg.get('spectral_flatness', 0):.4f}", flush=True)
except Exception as e:
record["status"] = "failed"
record["error"] = str(e)
print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
traceback.print_exc()
pbar.update(n_clips - len(record["clips"]))
finally:
# Free generator copy immediately — large model, many adapters
try:
del generator
except NameError:
@@ -365,7 +392,6 @@ class SelvaLoraEvaluator:
summary["adapters"].append(record)
_write_summary()
pbar.update(1)
# ------------------------------------------------------------------
# 5. Finalise summary
@@ -379,8 +405,8 @@ class SelvaLoraEvaluator:
# ------------------------------------------------------------------
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
if completed:
ids = [r["id"] for r in completed]
metrics_list = [r["spectral_metrics"] for r in completed]
ids = ["reference"] + [r["id"] for r in completed]
metrics_list = [summary["reference_avg"]] + [r["avg_metrics"] for r in completed]
chart_path = output_dir / "metric_comparison.png"
comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
+4 -5
View File
@@ -93,17 +93,16 @@ def _load_npz(path: Path) -> dict:
# ---------------------------------------------------------------------------
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
num_steps: int = 25, seed: int = 42):
"""Run a quick no-CFG inference pass on a fixed training clip.
num_steps: int = 25, seed: int = 42, clip_idx: int = 0):
"""Run a quick no-CFG inference pass on a training clip.
Always uses dataset[0] and a fixed noise seed so samples across checkpoints
Uses dataset[clip_idx] and a fixed noise seed so samples across checkpoints
are directly comparable — you can hear the model improve step by step.
Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
Uses fewer ODE steps than inference (8 vs 25) for speed.
"""
generator.eval()
try:
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[0]
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype)