From fdce9cbbf1161d98af8b0336ca0d19453120022d Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Wed, 8 Apr 2026 17:42:55 +0200 Subject: [PATCH] feat: evaluate adapters on all dataset clips, not just clip_001 - _eval_sample gains clip_idx param (default 0, backward compatible) - Evaluator loops over all dataset clips per adapter, saves one WAV per clip - Reference metrics computed for all clips and averaged - Comparison chart and summary use avg_metrics across all clips - Eliminates bias from evaluating on an unrepresentative single clip Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_lora_evaluator.py | 170 ++++++++++++++++++++-------------- nodes/selva_lora_trainer.py | 9 +- 2 files changed, 102 insertions(+), 77 deletions(-) diff --git a/nodes/selva_lora_evaluator.py b/nodes/selva_lora_evaluator.py index 42694e7..4360570 100644 --- a/nodes/selva_lora_evaluator.py +++ b/nodes/selva_lora_evaluator.py @@ -44,6 +44,16 @@ from .selva_lora_trainer import ( from selva_core.model.lora import apply_lora, load_lora +def _avg_metrics(metrics_list: list) -> dict: + """Average spectral metrics across multiple clips, ignoring None entries.""" + keys = ["hf_energy_ratio", "spectral_centroid_hz", "spectral_rolloff_hz", + "spectral_flatness", "temporal_variance"] + valid = [m for m in metrics_list if m] + if not valid: + return {} + return {k: round(float(sum(m[k] for m in valid) / len(valid)), 4) for k in keys} + + def _resolve_path(raw: str) -> Path: p = Path(raw.strip()) unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive @@ -200,42 +210,56 @@ class SelvaLoraEvaluator: seq_cfg = model["seq_cfg"] # ------------------------------------------------------------------ - # 3. Load reference audio (first clip, same one used for eval samples) + # 3. Collect reference metrics for all dataset clips # ------------------------------------------------------------------ - first_npz = sorted(data_dir.glob("*.npz"))[0] - audio_path = _find_audio(first_npz) - ref_record = {"id": "reference", "path": str(audio_path) if audio_path else None, - "wav_path": None, "spectrogram_path": None, - "spectral_metrics": None, "status": "failed"} - if audio_path: + import shutil + npz_files = sorted(data_dir.glob("*.npz")) + ref_dir = output_dir / "reference" + ref_dir.mkdir(exist_ok=True) + ref_clips = [] # list of {clip, wav_path, spectral_metrics} + + print(f"[LoRA Evaluator] Computing reference metrics for {len(npz_files)} clip(s)...", + flush=True) + for npz_path in npz_files: + audio_path = _find_audio(npz_path) + if audio_path is None: + continue try: ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration) ref_wav = ref_wav.unsqueeze(0) # [1, L] - import shutil - ref_out = output_dir / f"reference{audio_path.suffix}" + ref_out = ref_dir / f"{npz_path.stem}{audio_path.suffix}" shutil.copy2(str(audio_path), str(ref_out)) - ref_record["wav_path"] = str(ref_out) - ref_record["spectral_metrics"] = _spectral_metrics(ref_wav, seq_cfg.sampling_rate) - _save_spectrogram(ref_wav, seq_cfg.sampling_rate, output_dir / "reference") - ref_record["spectrogram_path"] = str((output_dir / "reference").with_suffix(".png")) - ref_record["status"] = "completed" - print(f"[LoRA Evaluator] Reference: {audio_path.name}", flush=True) + metrics = _spectral_metrics(ref_wav, seq_cfg.sampling_rate) + ref_clips.append({ + "clip": npz_path.stem, + "wav_path": str(ref_out), + "spectral_metrics": metrics, + }) except Exception as e: - print(f"[LoRA Evaluator] Reference load failed: {e}", flush=True) + print(f"[LoRA Evaluator] Reference {npz_path.name} failed: {e}", flush=True) + + # Average reference metrics across all clips + ref_avg = _avg_metrics([c["spectral_metrics"] for c in ref_clips]) + print(f"[LoRA Evaluator] Reference avg — " + f"centroid={ref_avg.get('spectral_centroid_hz', 0):.0f}Hz " + f"hf={ref_avg.get('hf_energy_ratio', 0):.3f} " + f"flatness={ref_avg.get('spectral_flatness', 0):.4f}", flush=True) # ------------------------------------------------------------------ # 4. Build summary skeleton # ------------------------------------------------------------------ summary = { - "name": name, - "started_at": datetime.now(timezone.utc).isoformat(), - "completed_at": None, - "data_dir": str(data_dir), - "output_dir": str(output_dir), - "reference": first_npz.name, - "steps": steps, - "seed": seed, - "adapters": [ref_record], + "name": name, + "started_at": datetime.now(timezone.utc).isoformat(), + "completed_at": None, + "data_dir": str(data_dir), + "output_dir": str(output_dir), + "n_clips": len(ref_clips), + "steps": steps, + "seed": seed, + "reference_avg": ref_avg, + "reference_clips": ref_clips, + "adapters": [], } summary_path = output_dir / "eval_summary.json" @@ -245,33 +269,33 @@ class SelvaLoraEvaluator: _write_summary() # ------------------------------------------------------------------ - # 5. Per-adapter evaluation loop + # 5. Per-adapter evaluation loop (all clips) # ------------------------------------------------------------------ - pbar = comfy.utils.ProgressBar(len(spec["adapters"])) + n_clips = len(dataset) + pbar = comfy.utils.ProgressBar(len(spec["adapters"]) * n_clips) for adapter_spec in spec["adapters"]: adapter_id = adapter_spec["id"] adapter_path = (adapter_spec.get("path") or "").strip() safe_id = _safe_stem(adapter_id) + clip_dir = output_dir / safe_id + clip_dir.mkdir(exist_ok=True) record = { - "id": adapter_id, - "path": adapter_path or None, - "meta": None, - "wav_path": None, - "spectrogram_path": None, - "spectral_metrics": None, - "status": "running", + "id": adapter_id, + "path": adapter_path or None, + "meta": None, + "clips": [], + "avg_metrics": None, + "status": "running", } - print(f"[LoRA Evaluator] ── '{adapter_id}' ──", flush=True) + print(f"\n[LoRA Evaluator] ── '{adapter_id}' ({n_clips} clips) ──", flush=True) try: with torch.inference_mode(False): - # 4a. Deep-copy generator generator = copy.deepcopy(model["generator"]) - # 4b. Apply + load LoRA if path given if adapter_path: pt_path = Path(adapter_path) if not pt_path.is_absolute(): @@ -306,7 +330,6 @@ class SelvaLoraEvaluator: else: print("[LoRA Evaluator] Baseline (no LoRA)", flush=True) - # 4c. Move to device and set sequence lengths generator = generator.to(device, dtype) generator.update_seq_lengths( latent_seq_len=seq_cfg.latent_seq_len, @@ -314,49 +337,53 @@ class SelvaLoraEvaluator: sync_seq_len=seq_cfg.sync_seq_len, ) - # 4d. Run inference - wav, sr = _eval_sample( - generator, feature_utils_orig, dataset, - seq_cfg, device, dtype, - num_steps=steps, seed=seed, - ) + clip_metrics_list = [] + for clip_idx in range(n_clips): + clip_stem = npz_files[clip_idx].stem + wav, sr = _eval_sample( + generator, feature_utils_orig, dataset, + seq_cfg, device, dtype, + num_steps=steps, seed=seed, clip_idx=clip_idx, + ) + if wav is None: + pbar.update(1) + continue - if wav is None: - raise RuntimeError("_eval_sample returned None") + wav_path = clip_dir / f"{clip_stem}.wav" + try: + torchaudio.save(str(wav_path), wav, sr) + except RuntimeError: + import soundfile as sf + sf.write(str(wav_path), wav.squeeze(0).numpy(), sr) - # 4e. Save wav - wav_path = output_dir / f"{safe_id}.wav" - try: - torchaudio.save(str(wav_path), wav, sr) - except RuntimeError: - import soundfile as sf - sf.write(str(wav_path), wav.squeeze(0).numpy(), sr) - record["wav_path"] = str(wav_path) - print(f"[LoRA Evaluator] Saved {wav_path.name}", flush=True) - - # 4f. Spectral metrics - metrics = _spectral_metrics(wav, sr) - record["spectral_metrics"] = metrics - print(f"[LoRA Evaluator] hf={metrics['hf_energy_ratio']:.3f} " - f"centroid={metrics['spectral_centroid_hz']:.0f}Hz " - f"flatness={metrics['spectral_flatness']:.3f} " - f"tv={metrics['temporal_variance']:.3f}", flush=True) - - # 4g. Spectrogram PNG - spec_path = output_dir / safe_id - _save_spectrogram(wav, sr, spec_path) - record["spectrogram_path"] = str(spec_path.with_suffix(".png")) + metrics = _spectral_metrics(wav, sr) + clip_metrics_list.append(metrics) + record["clips"].append({ + "clip": clip_stem, + "wav_path": str(wav_path), + "spectral_metrics": metrics, + }) + print(f" [{clip_idx+1}/{n_clips}] {clip_stem} " + f"centroid={metrics['spectral_centroid_hz']:.0f}Hz " + f"hf={metrics['hf_energy_ratio']:.3f}", flush=True) + pbar.update(1) + record["avg_metrics"] = _avg_metrics(clip_metrics_list) record["status"] = "completed" + avg = record["avg_metrics"] + print(f"[LoRA Evaluator] '{adapter_id}' avg — " + f"centroid={avg.get('spectral_centroid_hz', 0):.0f}Hz " + f"hf={avg.get('hf_energy_ratio', 0):.3f} " + f"flatness={avg.get('spectral_flatness', 0):.4f}", flush=True) except Exception as e: record["status"] = "failed" record["error"] = str(e) print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True) traceback.print_exc() + pbar.update(n_clips - len(record["clips"])) finally: - # Free generator copy immediately — large model, many adapters try: del generator except NameError: @@ -365,7 +392,6 @@ class SelvaLoraEvaluator: summary["adapters"].append(record) _write_summary() - pbar.update(1) # ------------------------------------------------------------------ # 5. Finalise summary @@ -379,8 +405,8 @@ class SelvaLoraEvaluator: # ------------------------------------------------------------------ completed = [r for r in summary["adapters"] if r.get("status") == "completed"] if completed: - ids = [r["id"] for r in completed] - metrics_list = [r["spectral_metrics"] for r in completed] + ids = ["reference"] + [r["id"] for r in completed] + metrics_list = [summary["reference_avg"]] + [r["avg_metrics"] for r in completed] chart_path = output_dir / "metric_comparison.png" comparison = _draw_metric_comparison(ids, metrics_list, chart_path) print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True) diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index c763661..76fbe1a 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -93,17 +93,16 @@ def _load_npz(path: Path) -> dict: # --------------------------------------------------------------------------- def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype, - num_steps: int = 25, seed: int = 42): - """Run a quick no-CFG inference pass on a fixed training clip. + num_steps: int = 25, seed: int = 42, clip_idx: int = 0): + """Run a quick no-CFG inference pass on a training clip. - Always uses dataset[0] and a fixed noise seed so samples across checkpoints + Uses dataset[clip_idx] and a fixed noise seed so samples across checkpoints are directly comparable — you can hear the model improve step by step. Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure. - Uses fewer ODE steps than inference (8 vs 25) for speed. """ generator.eval() try: - _, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[0] + _, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx] clip_f = clip_f_cpu.to(device, dtype) sync_f = sync_f_cpu.to(device, dtype) text_clip = text_clip_cpu.to(device, dtype)