feat: add reference audio to LoRA evaluator

Loads the first clip's original audio (same clip used for inference),
copies it to output_dir/reference.wav, runs spectral metrics and
saves a spectrogram. Appears first in the comparison chart so generated
samples can be judged against the target sound.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 17:30:33 +02:00
parent dbfa7b23fe
commit 4505b89db1
+31 -4
View File
@@ -38,6 +38,8 @@ from .selva_lora_trainer import (
_spectral_metrics, _spectral_metrics,
_save_spectrogram, _save_spectrogram,
_pil_to_tensor, _pil_to_tensor,
_find_audio,
_load_audio,
) )
from selva_core.model.lora import apply_lora, load_lora from selva_core.model.lora import apply_lora, load_lora
@@ -199,7 +201,31 @@ class SelvaLoraEvaluator:
seq_cfg = model["seq_cfg"] seq_cfg = model["seq_cfg"]
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# 3. Build summary skeleton # 3. Load reference audio (first clip, same one used for eval samples)
# ------------------------------------------------------------------
first_npz = sorted(data_dir.glob("*.npz"))[0]
audio_path = _find_audio(first_npz)
ref_record = {"id": "reference", "path": str(audio_path) if audio_path else None,
"wav_path": None, "spectrogram_path": None,
"spectral_metrics": None, "status": "failed"}
if audio_path:
try:
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
ref_wav = ref_wav.unsqueeze(0) # [1, L]
ref_out = output_dir / "reference.wav"
import shutil
shutil.copy2(str(audio_path), str(ref_out))
ref_record["wav_path"] = str(ref_out)
ref_record["spectral_metrics"] = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
_save_spectrogram(ref_wav, seq_cfg.sampling_rate, output_dir / "reference")
ref_record["spectrogram_path"] = str((output_dir / "reference").with_suffix(".png"))
ref_record["status"] = "completed"
print(f"[LoRA Evaluator] Reference: {audio_path.name}", flush=True)
except Exception as e:
print(f"[LoRA Evaluator] Reference load failed: {e}", flush=True)
# ------------------------------------------------------------------
# 4. Build summary skeleton
# ------------------------------------------------------------------ # ------------------------------------------------------------------
summary = { summary = {
"name": name, "name": name,
@@ -207,9 +233,10 @@ class SelvaLoraEvaluator:
"completed_at": None, "completed_at": None,
"data_dir": str(data_dir), "data_dir": str(data_dir),
"output_dir": str(output_dir), "output_dir": str(output_dir),
"reference": first_npz.name,
"steps": steps, "steps": steps,
"seed": seed, "seed": seed,
"adapters": [], "adapters": [ref_record],
} }
summary_path = output_dir / "eval_summary.json" summary_path = output_dir / "eval_summary.json"
@@ -219,7 +246,7 @@ class SelvaLoraEvaluator:
_write_summary() _write_summary()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# 4. Per-adapter evaluation loop # 5. Per-adapter evaluation loop
# ------------------------------------------------------------------ # ------------------------------------------------------------------
pbar = comfy.utils.ProgressBar(len(spec["adapters"])) pbar = comfy.utils.ProgressBar(len(spec["adapters"]))
@@ -351,7 +378,7 @@ class SelvaLoraEvaluator:
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# 6. Comparison chart # 6. Comparison chart
# ------------------------------------------------------------------ # ------------------------------------------------------------------
completed = [r for r in summary["adapters"] if r["status"] == "completed"] completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
if completed: if completed:
ids = [r["id"] for r in completed] ids = [r["id"] for r in completed]
metrics_list = [r["spectral_metrics"] for r in completed] metrics_list = [r["spectral_metrics"] for r in completed]