feat: add reference audio to LoRA evaluator
Loads the first clip's original audio (same clip used for inference), copies it to output_dir/reference.wav, runs spectral metrics and saves a spectrogram. Appears first in the comparison chart so generated samples can be judged against the target sound. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -38,6 +38,8 @@ from .selva_lora_trainer import (
|
|||||||
_spectral_metrics,
|
_spectral_metrics,
|
||||||
_save_spectrogram,
|
_save_spectrogram,
|
||||||
_pil_to_tensor,
|
_pil_to_tensor,
|
||||||
|
_find_audio,
|
||||||
|
_load_audio,
|
||||||
)
|
)
|
||||||
from selva_core.model.lora import apply_lora, load_lora
|
from selva_core.model.lora import apply_lora, load_lora
|
||||||
|
|
||||||
@@ -199,7 +201,31 @@ class SelvaLoraEvaluator:
|
|||||||
seq_cfg = model["seq_cfg"]
|
seq_cfg = model["seq_cfg"]
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 3. Build summary skeleton
|
# 3. Load reference audio (first clip, same one used for eval samples)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
first_npz = sorted(data_dir.glob("*.npz"))[0]
|
||||||
|
audio_path = _find_audio(first_npz)
|
||||||
|
ref_record = {"id": "reference", "path": str(audio_path) if audio_path else None,
|
||||||
|
"wav_path": None, "spectrogram_path": None,
|
||||||
|
"spectral_metrics": None, "status": "failed"}
|
||||||
|
if audio_path:
|
||||||
|
try:
|
||||||
|
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||||
|
ref_wav = ref_wav.unsqueeze(0) # [1, L]
|
||||||
|
ref_out = output_dir / "reference.wav"
|
||||||
|
import shutil
|
||||||
|
shutil.copy2(str(audio_path), str(ref_out))
|
||||||
|
ref_record["wav_path"] = str(ref_out)
|
||||||
|
ref_record["spectral_metrics"] = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
|
||||||
|
_save_spectrogram(ref_wav, seq_cfg.sampling_rate, output_dir / "reference")
|
||||||
|
ref_record["spectrogram_path"] = str((output_dir / "reference").with_suffix(".png"))
|
||||||
|
ref_record["status"] = "completed"
|
||||||
|
print(f"[LoRA Evaluator] Reference: {audio_path.name}", flush=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[LoRA Evaluator] Reference load failed: {e}", flush=True)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 4. Build summary skeleton
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
summary = {
|
summary = {
|
||||||
"name": name,
|
"name": name,
|
||||||
@@ -207,9 +233,10 @@ class SelvaLoraEvaluator:
|
|||||||
"completed_at": None,
|
"completed_at": None,
|
||||||
"data_dir": str(data_dir),
|
"data_dir": str(data_dir),
|
||||||
"output_dir": str(output_dir),
|
"output_dir": str(output_dir),
|
||||||
|
"reference": first_npz.name,
|
||||||
"steps": steps,
|
"steps": steps,
|
||||||
"seed": seed,
|
"seed": seed,
|
||||||
"adapters": [],
|
"adapters": [ref_record],
|
||||||
}
|
}
|
||||||
summary_path = output_dir / "eval_summary.json"
|
summary_path = output_dir / "eval_summary.json"
|
||||||
|
|
||||||
@@ -219,7 +246,7 @@ class SelvaLoraEvaluator:
|
|||||||
_write_summary()
|
_write_summary()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 4. Per-adapter evaluation loop
|
# 5. Per-adapter evaluation loop
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
pbar = comfy.utils.ProgressBar(len(spec["adapters"]))
|
pbar = comfy.utils.ProgressBar(len(spec["adapters"]))
|
||||||
|
|
||||||
@@ -351,7 +378,7 @@ class SelvaLoraEvaluator:
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 6. Comparison chart
|
# 6. Comparison chart
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
completed = [r for r in summary["adapters"] if r["status"] == "completed"]
|
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
|
||||||
if completed:
|
if completed:
|
||||||
ids = [r["id"] for r in completed]
|
ids = [r["id"] for r in completed]
|
||||||
metrics_list = [r["spectral_metrics"] for r in completed]
|
metrics_list = [r["spectral_metrics"] for r in completed]
|
||||||
|
|||||||
Reference in New Issue
Block a user