feat: add reference audio to LoRA evaluator
Loads the first clip's original audio (same clip used for inference), copies it to output_dir/reference.wav, runs spectral metrics and saves a spectrogram. Appears first in the comparison chart so generated samples can be judged against the target sound. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -38,6 +38,8 @@ from .selva_lora_trainer import (
|
||||
_spectral_metrics,
|
||||
_save_spectrogram,
|
||||
_pil_to_tensor,
|
||||
_find_audio,
|
||||
_load_audio,
|
||||
)
|
||||
from selva_core.model.lora import apply_lora, load_lora
|
||||
|
||||
@@ -199,7 +201,31 @@ class SelvaLoraEvaluator:
|
||||
seq_cfg = model["seq_cfg"]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. Build summary skeleton
|
||||
# 3. Load reference audio (first clip, same one used for eval samples)
|
||||
# ------------------------------------------------------------------
|
||||
first_npz = sorted(data_dir.glob("*.npz"))[0]
|
||||
audio_path = _find_audio(first_npz)
|
||||
ref_record = {"id": "reference", "path": str(audio_path) if audio_path else None,
|
||||
"wav_path": None, "spectrogram_path": None,
|
||||
"spectral_metrics": None, "status": "failed"}
|
||||
if audio_path:
|
||||
try:
|
||||
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||
ref_wav = ref_wav.unsqueeze(0) # [1, L]
|
||||
ref_out = output_dir / "reference.wav"
|
||||
import shutil
|
||||
shutil.copy2(str(audio_path), str(ref_out))
|
||||
ref_record["wav_path"] = str(ref_out)
|
||||
ref_record["spectral_metrics"] = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
|
||||
_save_spectrogram(ref_wav, seq_cfg.sampling_rate, output_dir / "reference")
|
||||
ref_record["spectrogram_path"] = str((output_dir / "reference").with_suffix(".png"))
|
||||
ref_record["status"] = "completed"
|
||||
print(f"[LoRA Evaluator] Reference: {audio_path.name}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[LoRA Evaluator] Reference load failed: {e}", flush=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. Build summary skeleton
|
||||
# ------------------------------------------------------------------
|
||||
summary = {
|
||||
"name": name,
|
||||
@@ -207,9 +233,10 @@ class SelvaLoraEvaluator:
|
||||
"completed_at": None,
|
||||
"data_dir": str(data_dir),
|
||||
"output_dir": str(output_dir),
|
||||
"reference": first_npz.name,
|
||||
"steps": steps,
|
||||
"seed": seed,
|
||||
"adapters": [],
|
||||
"adapters": [ref_record],
|
||||
}
|
||||
summary_path = output_dir / "eval_summary.json"
|
||||
|
||||
@@ -219,7 +246,7 @@ class SelvaLoraEvaluator:
|
||||
_write_summary()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. Per-adapter evaluation loop
|
||||
# 5. Per-adapter evaluation loop
|
||||
# ------------------------------------------------------------------
|
||||
pbar = comfy.utils.ProgressBar(len(spec["adapters"]))
|
||||
|
||||
@@ -351,7 +378,7 @@ class SelvaLoraEvaluator:
|
||||
# ------------------------------------------------------------------
|
||||
# 6. Comparison chart
|
||||
# ------------------------------------------------------------------
|
||||
completed = [r for r in summary["adapters"] if r["status"] == "completed"]
|
||||
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
|
||||
if completed:
|
||||
ids = [r["id"] for r in completed]
|
||||
metrics_list = [r["spectral_metrics"] for r in completed]
|
||||
|
||||
Reference in New Issue
Block a user