feat: evaluate adapters on all dataset clips, not just clip_001
- _eval_sample gains clip_idx param (default 0, backward compatible) - Evaluator loops over all dataset clips per adapter, saves one WAV per clip - Reference metrics computed for all clips and averaged - Comparison chart and summary use avg_metrics across all clips - Eliminates bias from evaluating on an unrepresentative single clip Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -44,6 +44,16 @@ from .selva_lora_trainer import (
|
|||||||
from selva_core.model.lora import apply_lora, load_lora
|
from selva_core.model.lora import apply_lora, load_lora
|
||||||
|
|
||||||
|
|
||||||
|
def _avg_metrics(metrics_list: list) -> dict:
|
||||||
|
"""Average spectral metrics across multiple clips, ignoring None entries."""
|
||||||
|
keys = ["hf_energy_ratio", "spectral_centroid_hz", "spectral_rolloff_hz",
|
||||||
|
"spectral_flatness", "temporal_variance"]
|
||||||
|
valid = [m for m in metrics_list if m]
|
||||||
|
if not valid:
|
||||||
|
return {}
|
||||||
|
return {k: round(float(sum(m[k] for m in valid) / len(valid)), 4) for k in keys}
|
||||||
|
|
||||||
|
|
||||||
def _resolve_path(raw: str) -> Path:
|
def _resolve_path(raw: str) -> Path:
|
||||||
p = Path(raw.strip())
|
p = Path(raw.strip())
|
||||||
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
|
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
|
||||||
@@ -200,42 +210,56 @@ class SelvaLoraEvaluator:
|
|||||||
seq_cfg = model["seq_cfg"]
|
seq_cfg = model["seq_cfg"]
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 3. Load reference audio (first clip, same one used for eval samples)
|
# 3. Collect reference metrics for all dataset clips
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
first_npz = sorted(data_dir.glob("*.npz"))[0]
|
import shutil
|
||||||
audio_path = _find_audio(first_npz)
|
npz_files = sorted(data_dir.glob("*.npz"))
|
||||||
ref_record = {"id": "reference", "path": str(audio_path) if audio_path else None,
|
ref_dir = output_dir / "reference"
|
||||||
"wav_path": None, "spectrogram_path": None,
|
ref_dir.mkdir(exist_ok=True)
|
||||||
"spectral_metrics": None, "status": "failed"}
|
ref_clips = [] # list of {clip, wav_path, spectral_metrics}
|
||||||
if audio_path:
|
|
||||||
|
print(f"[LoRA Evaluator] Computing reference metrics for {len(npz_files)} clip(s)...",
|
||||||
|
flush=True)
|
||||||
|
for npz_path in npz_files:
|
||||||
|
audio_path = _find_audio(npz_path)
|
||||||
|
if audio_path is None:
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||||
ref_wav = ref_wav.unsqueeze(0) # [1, L]
|
ref_wav = ref_wav.unsqueeze(0) # [1, L]
|
||||||
import shutil
|
ref_out = ref_dir / f"{npz_path.stem}{audio_path.suffix}"
|
||||||
ref_out = output_dir / f"reference{audio_path.suffix}"
|
|
||||||
shutil.copy2(str(audio_path), str(ref_out))
|
shutil.copy2(str(audio_path), str(ref_out))
|
||||||
ref_record["wav_path"] = str(ref_out)
|
metrics = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
|
||||||
ref_record["spectral_metrics"] = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
|
ref_clips.append({
|
||||||
_save_spectrogram(ref_wav, seq_cfg.sampling_rate, output_dir / "reference")
|
"clip": npz_path.stem,
|
||||||
ref_record["spectrogram_path"] = str((output_dir / "reference").with_suffix(".png"))
|
"wav_path": str(ref_out),
|
||||||
ref_record["status"] = "completed"
|
"spectral_metrics": metrics,
|
||||||
print(f"[LoRA Evaluator] Reference: {audio_path.name}", flush=True)
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[LoRA Evaluator] Reference load failed: {e}", flush=True)
|
print(f"[LoRA Evaluator] Reference {npz_path.name} failed: {e}", flush=True)
|
||||||
|
|
||||||
|
# Average reference metrics across all clips
|
||||||
|
ref_avg = _avg_metrics([c["spectral_metrics"] for c in ref_clips])
|
||||||
|
print(f"[LoRA Evaluator] Reference avg — "
|
||||||
|
f"centroid={ref_avg.get('spectral_centroid_hz', 0):.0f}Hz "
|
||||||
|
f"hf={ref_avg.get('hf_energy_ratio', 0):.3f} "
|
||||||
|
f"flatness={ref_avg.get('spectral_flatness', 0):.4f}", flush=True)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 4. Build summary skeleton
|
# 4. Build summary skeleton
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
summary = {
|
summary = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||||
"completed_at": None,
|
"completed_at": None,
|
||||||
"data_dir": str(data_dir),
|
"data_dir": str(data_dir),
|
||||||
"output_dir": str(output_dir),
|
"output_dir": str(output_dir),
|
||||||
"reference": first_npz.name,
|
"n_clips": len(ref_clips),
|
||||||
"steps": steps,
|
"steps": steps,
|
||||||
"seed": seed,
|
"seed": seed,
|
||||||
"adapters": [ref_record],
|
"reference_avg": ref_avg,
|
||||||
|
"reference_clips": ref_clips,
|
||||||
|
"adapters": [],
|
||||||
}
|
}
|
||||||
summary_path = output_dir / "eval_summary.json"
|
summary_path = output_dir / "eval_summary.json"
|
||||||
|
|
||||||
@@ -245,33 +269,33 @@ class SelvaLoraEvaluator:
|
|||||||
_write_summary()
|
_write_summary()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 5. Per-adapter evaluation loop
|
# 5. Per-adapter evaluation loop (all clips)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
pbar = comfy.utils.ProgressBar(len(spec["adapters"]))
|
n_clips = len(dataset)
|
||||||
|
pbar = comfy.utils.ProgressBar(len(spec["adapters"]) * n_clips)
|
||||||
|
|
||||||
for adapter_spec in spec["adapters"]:
|
for adapter_spec in spec["adapters"]:
|
||||||
adapter_id = adapter_spec["id"]
|
adapter_id = adapter_spec["id"]
|
||||||
adapter_path = (adapter_spec.get("path") or "").strip()
|
adapter_path = (adapter_spec.get("path") or "").strip()
|
||||||
safe_id = _safe_stem(adapter_id)
|
safe_id = _safe_stem(adapter_id)
|
||||||
|
clip_dir = output_dir / safe_id
|
||||||
|
clip_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
record = {
|
record = {
|
||||||
"id": adapter_id,
|
"id": adapter_id,
|
||||||
"path": adapter_path or None,
|
"path": adapter_path or None,
|
||||||
"meta": None,
|
"meta": None,
|
||||||
"wav_path": None,
|
"clips": [],
|
||||||
"spectrogram_path": None,
|
"avg_metrics": None,
|
||||||
"spectral_metrics": None,
|
"status": "running",
|
||||||
"status": "running",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"[LoRA Evaluator] ── '{adapter_id}' ──", flush=True)
|
print(f"\n[LoRA Evaluator] ── '{adapter_id}' ({n_clips} clips) ──", flush=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.inference_mode(False):
|
with torch.inference_mode(False):
|
||||||
# 4a. Deep-copy generator
|
|
||||||
generator = copy.deepcopy(model["generator"])
|
generator = copy.deepcopy(model["generator"])
|
||||||
|
|
||||||
# 4b. Apply + load LoRA if path given
|
|
||||||
if adapter_path:
|
if adapter_path:
|
||||||
pt_path = Path(adapter_path)
|
pt_path = Path(adapter_path)
|
||||||
if not pt_path.is_absolute():
|
if not pt_path.is_absolute():
|
||||||
@@ -306,7 +330,6 @@ class SelvaLoraEvaluator:
|
|||||||
else:
|
else:
|
||||||
print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)
|
print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)
|
||||||
|
|
||||||
# 4c. Move to device and set sequence lengths
|
|
||||||
generator = generator.to(device, dtype)
|
generator = generator.to(device, dtype)
|
||||||
generator.update_seq_lengths(
|
generator.update_seq_lengths(
|
||||||
latent_seq_len=seq_cfg.latent_seq_len,
|
latent_seq_len=seq_cfg.latent_seq_len,
|
||||||
@@ -314,49 +337,53 @@ class SelvaLoraEvaluator:
|
|||||||
sync_seq_len=seq_cfg.sync_seq_len,
|
sync_seq_len=seq_cfg.sync_seq_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4d. Run inference
|
clip_metrics_list = []
|
||||||
wav, sr = _eval_sample(
|
for clip_idx in range(n_clips):
|
||||||
generator, feature_utils_orig, dataset,
|
clip_stem = npz_files[clip_idx].stem
|
||||||
seq_cfg, device, dtype,
|
wav, sr = _eval_sample(
|
||||||
num_steps=steps, seed=seed,
|
generator, feature_utils_orig, dataset,
|
||||||
)
|
seq_cfg, device, dtype,
|
||||||
|
num_steps=steps, seed=seed, clip_idx=clip_idx,
|
||||||
|
)
|
||||||
|
if wav is None:
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
if wav is None:
|
wav_path = clip_dir / f"{clip_stem}.wav"
|
||||||
raise RuntimeError("_eval_sample returned None")
|
try:
|
||||||
|
torchaudio.save(str(wav_path), wav, sr)
|
||||||
|
except RuntimeError:
|
||||||
|
import soundfile as sf
|
||||||
|
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||||
|
|
||||||
# 4e. Save wav
|
metrics = _spectral_metrics(wav, sr)
|
||||||
wav_path = output_dir / f"{safe_id}.wav"
|
clip_metrics_list.append(metrics)
|
||||||
try:
|
record["clips"].append({
|
||||||
torchaudio.save(str(wav_path), wav, sr)
|
"clip": clip_stem,
|
||||||
except RuntimeError:
|
"wav_path": str(wav_path),
|
||||||
import soundfile as sf
|
"spectral_metrics": metrics,
|
||||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
})
|
||||||
record["wav_path"] = str(wav_path)
|
print(f" [{clip_idx+1}/{n_clips}] {clip_stem} "
|
||||||
print(f"[LoRA Evaluator] Saved {wav_path.name}", flush=True)
|
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
||||||
|
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
|
||||||
# 4f. Spectral metrics
|
pbar.update(1)
|
||||||
metrics = _spectral_metrics(wav, sr)
|
|
||||||
record["spectral_metrics"] = metrics
|
|
||||||
print(f"[LoRA Evaluator] hf={metrics['hf_energy_ratio']:.3f} "
|
|
||||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
|
||||||
f"flatness={metrics['spectral_flatness']:.3f} "
|
|
||||||
f"tv={metrics['temporal_variance']:.3f}", flush=True)
|
|
||||||
|
|
||||||
# 4g. Spectrogram PNG
|
|
||||||
spec_path = output_dir / safe_id
|
|
||||||
_save_spectrogram(wav, sr, spec_path)
|
|
||||||
record["spectrogram_path"] = str(spec_path.with_suffix(".png"))
|
|
||||||
|
|
||||||
|
record["avg_metrics"] = _avg_metrics(clip_metrics_list)
|
||||||
record["status"] = "completed"
|
record["status"] = "completed"
|
||||||
|
avg = record["avg_metrics"]
|
||||||
|
print(f"[LoRA Evaluator] '{adapter_id}' avg — "
|
||||||
|
f"centroid={avg.get('spectral_centroid_hz', 0):.0f}Hz "
|
||||||
|
f"hf={avg.get('hf_energy_ratio', 0):.3f} "
|
||||||
|
f"flatness={avg.get('spectral_flatness', 0):.4f}", flush=True)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
record["status"] = "failed"
|
record["status"] = "failed"
|
||||||
record["error"] = str(e)
|
record["error"] = str(e)
|
||||||
print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
|
print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
pbar.update(n_clips - len(record["clips"]))
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Free generator copy immediately — large model, many adapters
|
|
||||||
try:
|
try:
|
||||||
del generator
|
del generator
|
||||||
except NameError:
|
except NameError:
|
||||||
@@ -365,7 +392,6 @@ class SelvaLoraEvaluator:
|
|||||||
|
|
||||||
summary["adapters"].append(record)
|
summary["adapters"].append(record)
|
||||||
_write_summary()
|
_write_summary()
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 5. Finalise summary
|
# 5. Finalise summary
|
||||||
@@ -379,8 +405,8 @@ class SelvaLoraEvaluator:
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
|
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
|
||||||
if completed:
|
if completed:
|
||||||
ids = [r["id"] for r in completed]
|
ids = ["reference"] + [r["id"] for r in completed]
|
||||||
metrics_list = [r["spectral_metrics"] for r in completed]
|
metrics_list = [summary["reference_avg"]] + [r["avg_metrics"] for r in completed]
|
||||||
chart_path = output_dir / "metric_comparison.png"
|
chart_path = output_dir / "metric_comparison.png"
|
||||||
comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
|
comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
|
||||||
print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
|
print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
|
||||||
|
|||||||
@@ -93,17 +93,16 @@ def _load_npz(path: Path) -> dict:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
|
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
|
||||||
num_steps: int = 25, seed: int = 42):
|
num_steps: int = 25, seed: int = 42, clip_idx: int = 0):
|
||||||
"""Run a quick no-CFG inference pass on a fixed training clip.
|
"""Run a quick no-CFG inference pass on a training clip.
|
||||||
|
|
||||||
Always uses dataset[0] and a fixed noise seed so samples across checkpoints
|
Uses dataset[clip_idx] and a fixed noise seed so samples across checkpoints
|
||||||
are directly comparable — you can hear the model improve step by step.
|
are directly comparable — you can hear the model improve step by step.
|
||||||
Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
|
Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
|
||||||
Uses fewer ODE steps than inference (8 vs 25) for speed.
|
|
||||||
"""
|
"""
|
||||||
generator.eval()
|
generator.eval()
|
||||||
try:
|
try:
|
||||||
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[0]
|
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
|
||||||
clip_f = clip_f_cpu.to(device, dtype)
|
clip_f = clip_f_cpu.to(device, dtype)
|
||||||
sync_f = sync_f_cpu.to(device, dtype)
|
sync_f = sync_f_cpu.to(device, dtype)
|
||||||
text_clip = text_clip_cpu.to(device, dtype)
|
text_clip = text_clip_cpu.to(device, dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user