feat: evaluate adapters on all dataset clips, not just clip_001
- _eval_sample gains a clip_idx param (default 0, backward compatible)
- Evaluator loops over all dataset clips per adapter and saves one WAV per clip
- Reference metrics are computed for all clips and averaged
- Comparison chart and summary use avg_metrics across all clips
- Eliminates bias from evaluating on a single, unrepresentative clip

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -44,6 +44,16 @@ from .selva_lora_trainer import (
|
||||
from selva_core.model.lora import apply_lora, load_lora
|
||||
|
||||
|
||||
def _avg_metrics(metrics_list: list) -> dict:
|
||||
"""Average spectral metrics across multiple clips, ignoring None entries."""
|
||||
keys = ["hf_energy_ratio", "spectral_centroid_hz", "spectral_rolloff_hz",
|
||||
"spectral_flatness", "temporal_variance"]
|
||||
valid = [m for m in metrics_list if m]
|
||||
if not valid:
|
||||
return {}
|
||||
return {k: round(float(sum(m[k] for m in valid) / len(valid)), 4) for k in keys}
|
||||
|
||||
|
||||
def _resolve_path(raw: str) -> Path:
|
||||
p = Path(raw.strip())
|
||||
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
|
||||
@@ -200,42 +210,56 @@ class SelvaLoraEvaluator:
|
||||
seq_cfg = model["seq_cfg"]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. Load reference audio (first clip, same one used for eval samples)
|
||||
# 3. Collect reference metrics for all dataset clips
|
||||
# ------------------------------------------------------------------
|
||||
first_npz = sorted(data_dir.glob("*.npz"))[0]
|
||||
audio_path = _find_audio(first_npz)
|
||||
ref_record = {"id": "reference", "path": str(audio_path) if audio_path else None,
|
||||
"wav_path": None, "spectrogram_path": None,
|
||||
"spectral_metrics": None, "status": "failed"}
|
||||
if audio_path:
|
||||
import shutil
|
||||
npz_files = sorted(data_dir.glob("*.npz"))
|
||||
ref_dir = output_dir / "reference"
|
||||
ref_dir.mkdir(exist_ok=True)
|
||||
ref_clips = [] # list of {clip, wav_path, spectral_metrics}
|
||||
|
||||
print(f"[LoRA Evaluator] Computing reference metrics for {len(npz_files)} clip(s)...",
|
||||
flush=True)
|
||||
for npz_path in npz_files:
|
||||
audio_path = _find_audio(npz_path)
|
||||
if audio_path is None:
|
||||
continue
|
||||
try:
|
||||
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||
ref_wav = ref_wav.unsqueeze(0) # [1, L]
|
||||
import shutil
|
||||
ref_out = output_dir / f"reference{audio_path.suffix}"
|
||||
ref_out = ref_dir / f"{npz_path.stem}{audio_path.suffix}"
|
||||
shutil.copy2(str(audio_path), str(ref_out))
|
||||
ref_record["wav_path"] = str(ref_out)
|
||||
ref_record["spectral_metrics"] = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
|
||||
_save_spectrogram(ref_wav, seq_cfg.sampling_rate, output_dir / "reference")
|
||||
ref_record["spectrogram_path"] = str((output_dir / "reference").with_suffix(".png"))
|
||||
ref_record["status"] = "completed"
|
||||
print(f"[LoRA Evaluator] Reference: {audio_path.name}", flush=True)
|
||||
metrics = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
|
||||
ref_clips.append({
|
||||
"clip": npz_path.stem,
|
||||
"wav_path": str(ref_out),
|
||||
"spectral_metrics": metrics,
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"[LoRA Evaluator] Reference load failed: {e}", flush=True)
|
||||
print(f"[LoRA Evaluator] Reference {npz_path.name} failed: {e}", flush=True)
|
||||
|
||||
# Average reference metrics across all clips
|
||||
ref_avg = _avg_metrics([c["spectral_metrics"] for c in ref_clips])
|
||||
print(f"[LoRA Evaluator] Reference avg — "
|
||||
f"centroid={ref_avg.get('spectral_centroid_hz', 0):.0f}Hz "
|
||||
f"hf={ref_avg.get('hf_energy_ratio', 0):.3f} "
|
||||
f"flatness={ref_avg.get('spectral_flatness', 0):.4f}", flush=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. Build summary skeleton
|
||||
# ------------------------------------------------------------------
|
||||
summary = {
|
||||
"name": name,
|
||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||
"completed_at": None,
|
||||
"data_dir": str(data_dir),
|
||||
"output_dir": str(output_dir),
|
||||
"reference": first_npz.name,
|
||||
"steps": steps,
|
||||
"seed": seed,
|
||||
"adapters": [ref_record],
|
||||
"name": name,
|
||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||
"completed_at": None,
|
||||
"data_dir": str(data_dir),
|
||||
"output_dir": str(output_dir),
|
||||
"n_clips": len(ref_clips),
|
||||
"steps": steps,
|
||||
"seed": seed,
|
||||
"reference_avg": ref_avg,
|
||||
"reference_clips": ref_clips,
|
||||
"adapters": [],
|
||||
}
|
||||
summary_path = output_dir / "eval_summary.json"
|
||||
|
||||
@@ -245,33 +269,33 @@ class SelvaLoraEvaluator:
|
||||
_write_summary()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. Per-adapter evaluation loop
|
||||
# 5. Per-adapter evaluation loop (all clips)
|
||||
# ------------------------------------------------------------------
|
||||
pbar = comfy.utils.ProgressBar(len(spec["adapters"]))
|
||||
n_clips = len(dataset)
|
||||
pbar = comfy.utils.ProgressBar(len(spec["adapters"]) * n_clips)
|
||||
|
||||
for adapter_spec in spec["adapters"]:
|
||||
adapter_id = adapter_spec["id"]
|
||||
adapter_path = (adapter_spec.get("path") or "").strip()
|
||||
safe_id = _safe_stem(adapter_id)
|
||||
clip_dir = output_dir / safe_id
|
||||
clip_dir.mkdir(exist_ok=True)
|
||||
|
||||
record = {
|
||||
"id": adapter_id,
|
||||
"path": adapter_path or None,
|
||||
"meta": None,
|
||||
"wav_path": None,
|
||||
"spectrogram_path": None,
|
||||
"spectral_metrics": None,
|
||||
"status": "running",
|
||||
"id": adapter_id,
|
||||
"path": adapter_path or None,
|
||||
"meta": None,
|
||||
"clips": [],
|
||||
"avg_metrics": None,
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
print(f"[LoRA Evaluator] ── '{adapter_id}' ──", flush=True)
|
||||
print(f"\n[LoRA Evaluator] ── '{adapter_id}' ({n_clips} clips) ──", flush=True)
|
||||
|
||||
try:
|
||||
with torch.inference_mode(False):
|
||||
# 4a. Deep-copy generator
|
||||
generator = copy.deepcopy(model["generator"])
|
||||
|
||||
# 4b. Apply + load LoRA if path given
|
||||
if adapter_path:
|
||||
pt_path = Path(adapter_path)
|
||||
if not pt_path.is_absolute():
|
||||
@@ -306,7 +330,6 @@ class SelvaLoraEvaluator:
|
||||
else:
|
||||
print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)
|
||||
|
||||
# 4c. Move to device and set sequence lengths
|
||||
generator = generator.to(device, dtype)
|
||||
generator.update_seq_lengths(
|
||||
latent_seq_len=seq_cfg.latent_seq_len,
|
||||
@@ -314,49 +337,53 @@ class SelvaLoraEvaluator:
|
||||
sync_seq_len=seq_cfg.sync_seq_len,
|
||||
)
|
||||
|
||||
# 4d. Run inference
|
||||
wav, sr = _eval_sample(
|
||||
generator, feature_utils_orig, dataset,
|
||||
seq_cfg, device, dtype,
|
||||
num_steps=steps, seed=seed,
|
||||
)
|
||||
clip_metrics_list = []
|
||||
for clip_idx in range(n_clips):
|
||||
clip_stem = npz_files[clip_idx].stem
|
||||
wav, sr = _eval_sample(
|
||||
generator, feature_utils_orig, dataset,
|
||||
seq_cfg, device, dtype,
|
||||
num_steps=steps, seed=seed, clip_idx=clip_idx,
|
||||
)
|
||||
if wav is None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
if wav is None:
|
||||
raise RuntimeError("_eval_sample returned None")
|
||||
wav_path = clip_dir / f"{clip_stem}.wav"
|
||||
try:
|
||||
torchaudio.save(str(wav_path), wav, sr)
|
||||
except RuntimeError:
|
||||
import soundfile as sf
|
||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||
|
||||
# 4e. Save wav
|
||||
wav_path = output_dir / f"{safe_id}.wav"
|
||||
try:
|
||||
torchaudio.save(str(wav_path), wav, sr)
|
||||
except RuntimeError:
|
||||
import soundfile as sf
|
||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||
record["wav_path"] = str(wav_path)
|
||||
print(f"[LoRA Evaluator] Saved {wav_path.name}", flush=True)
|
||||
|
||||
# 4f. Spectral metrics
|
||||
metrics = _spectral_metrics(wav, sr)
|
||||
record["spectral_metrics"] = metrics
|
||||
print(f"[LoRA Evaluator] hf={metrics['hf_energy_ratio']:.3f} "
|
||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
||||
f"flatness={metrics['spectral_flatness']:.3f} "
|
||||
f"tv={metrics['temporal_variance']:.3f}", flush=True)
|
||||
|
||||
# 4g. Spectrogram PNG
|
||||
spec_path = output_dir / safe_id
|
||||
_save_spectrogram(wav, sr, spec_path)
|
||||
record["spectrogram_path"] = str(spec_path.with_suffix(".png"))
|
||||
metrics = _spectral_metrics(wav, sr)
|
||||
clip_metrics_list.append(metrics)
|
||||
record["clips"].append({
|
||||
"clip": clip_stem,
|
||||
"wav_path": str(wav_path),
|
||||
"spectral_metrics": metrics,
|
||||
})
|
||||
print(f" [{clip_idx+1}/{n_clips}] {clip_stem} "
|
||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
||||
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
|
||||
pbar.update(1)
|
||||
|
||||
record["avg_metrics"] = _avg_metrics(clip_metrics_list)
|
||||
record["status"] = "completed"
|
||||
avg = record["avg_metrics"]
|
||||
print(f"[LoRA Evaluator] '{adapter_id}' avg — "
|
||||
f"centroid={avg.get('spectral_centroid_hz', 0):.0f}Hz "
|
||||
f"hf={avg.get('hf_energy_ratio', 0):.3f} "
|
||||
f"flatness={avg.get('spectral_flatness', 0):.4f}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
record["status"] = "failed"
|
||||
record["error"] = str(e)
|
||||
print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
|
||||
traceback.print_exc()
|
||||
pbar.update(n_clips - len(record["clips"]))
|
||||
|
||||
finally:
|
||||
# Free generator copy immediately — large model, many adapters
|
||||
try:
|
||||
del generator
|
||||
except NameError:
|
||||
@@ -365,7 +392,6 @@ class SelvaLoraEvaluator:
|
||||
|
||||
summary["adapters"].append(record)
|
||||
_write_summary()
|
||||
pbar.update(1)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. Finalise summary
|
||||
@@ -379,8 +405,8 @@ class SelvaLoraEvaluator:
|
||||
# ------------------------------------------------------------------
|
||||
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
|
||||
if completed:
|
||||
ids = [r["id"] for r in completed]
|
||||
metrics_list = [r["spectral_metrics"] for r in completed]
|
||||
ids = ["reference"] + [r["id"] for r in completed]
|
||||
metrics_list = [summary["reference_avg"]] + [r["avg_metrics"] for r in completed]
|
||||
chart_path = output_dir / "metric_comparison.png"
|
||||
comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
|
||||
print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
|
||||
|
||||
Reference in New Issue
Block a user