Files
ComfyUI-SelVA/nodes/selva_lora_evaluator.py
T
Ethanfel fdce9cbbf1 feat: evaluate adapters on all dataset clips, not just clip_001
- _eval_sample gains clip_idx param (default 0, backward compatible)
- Evaluator loops over all dataset clips per adapter, saves one WAV per clip
- Reference metrics computed for all clips and averaged
- Comparison chart and summary use avg_metrics across all clips
- Eliminates bias from evaluating on an unrepresentative single clip

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:42:55 +02:00

418 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
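"""SelVA LoRA Evaluator — generates eval samples from multiple adapters for comparison.

Every adapter is evaluated on every clip of the prepared dataset; spectral
metrics are averaged per adapter and compared with the dataset's reference audio.

JSON format:
    {
        "name": "eval_batch_1",
        "data_dir": "/path/to/features",
        "output_dir": "/path/to/evals/batch1",
        "steps": 25,
        "seed": 42,
        "adapters": [
            {"id": "baseline"},
            {"id": "lr_3e4_10k", "path": "/path/to/adapter_final.pt"},
            {"id": "lr_5e4_10k", "path": "/path/to/adapter_final.pt"}
        ]
    }

Empty / missing "path" = baseline (no LoRA applied).
"name" defaults to the eval file's stem; "steps" and "seed" default to 25 and 42.
Relative "data_dir" / "output_dir" resolve under the ComfyUI output directory;
relative adapter paths resolve under the ComfyUI base path.
"""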
import copy
import json
import sys
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import torchaudio
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_lora_trainer import (
_prepare_dataset,
_eval_sample,
_spectral_metrics,
_save_spectrogram,
_pil_to_tensor,
_find_audio,
_load_audio,
)
from selva_core.model.lora import apply_lora, load_lora
def _avg_metrics(metrics_list: list) -> dict:
"""Average spectral metrics across multiple clips, ignoring None entries."""
keys = ["hf_energy_ratio", "spectral_centroid_hz", "spectral_rolloff_hz",
"spectral_flatness", "temporal_variance"]
valid = [m for m in metrics_list if m]
if not valid:
return {}
return {k: round(float(sum(m[k] for m in valid) / len(valid)), 4) for k in keys}
def _resolve_path(raw: str) -> Path:
    """Resolve a spec path, anchoring non-absolute paths in the output directory.

    Surrounding whitespace is stripped and absolute paths are returned as-is.
    Any other path has its anchor removed and is joined onto
    ``folder_paths.get_output_directory()``. Besides plain relative paths, this
    covers Windows paths that pathlib does not consider absolute:
    rooted-but-driveless (``/data``) and drive-relative (``C:data``).
    """
    p = Path(raw.strip())
    # On Windows, pathlib's is_absolute() requires a drive, so "/data" is
    # handled here too; relative_to(anchor) strips the lone root/drive.
    if not p.is_absolute():
        p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
    return p
def _safe_stem(adapter_id: str) -> str:
"""Replace characters illegal in filenames."""
for ch in r'/\:*?"<>|':
adapter_id = adapter_id.replace(ch, "_")
return adapter_id
def _draw_metric_comparison(adapter_ids: list, metrics_list: list, output_path: Path):
    """Draw a 2×2 grid of horizontal bar charts comparing spectral metrics.

    ``adapter_ids`` and ``metrics_list`` are paired by position, one bar per
    entry. A ``None``/empty metrics dict, or a missing key, is drawn as 0.0.
    Saves a PNG to ``output_path`` and returns a ComfyUI IMAGE tensor built
    with ``_pil_to_tensor``.
    """
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from PIL import Image

    METRICS = [
        ("hf_energy_ratio", "HF Energy Ratio (>4 kHz)"),
        ("spectral_centroid_hz", "Spectral Centroid (Hz)"),
        ("spectral_flatness", "Spectral Flatness"),
        ("temporal_variance", "Temporal Variance"),
    ]
    COLORS = [
        "#4285F4", "#EA4335", "#34A853", "#FBBC05",
        "#9B59B6", "#1ABC9C", "#E67E22", "#95A5A6",
    ]

    # Plot against explicit numeric positions and label the ticks instead of
    # passing the ids as categorical values. Duplicate ids would otherwise
    # share one bar slot, and non-string ids break matplotlib's category axis.
    # Position 0 is at the bottom, matching the categorical ordering.
    labels = [str(a) for a in adapter_ids]
    positions = list(range(len(labels)))
    colors = [COLORS[i % len(COLORS)] for i in range(len(metrics_list))]

    fig = Figure(figsize=(12, max(4, len(adapter_ids) * 0.6 + 2)), dpi=110, tight_layout=True)
    axes = [fig.add_subplot(2, 2, i + 1) for i in range(4)]
    for ax, (key, title) in zip(axes, METRICS):
        values = [(m.get(key, 0.0) if m else 0.0) for m in metrics_list]
        bars = ax.barh(positions, values, color=colors, height=0.6)
        ax.set_yticks(positions)
        ax.set_yticklabels(labels)
        ax.set_title(title, fontsize=9)
        ax.set_xlabel(key, fontsize=8)
        ax.tick_params(axis="y", labelsize=7)
        ax.tick_params(axis="x", labelsize=7)
        # Value labels just past the end of each bar
        for bar, val in zip(bars, values):
            w = bar.get_width()
            ax.text(w * 1.01, bar.get_y() + bar.get_height() / 2,
                    f"{val:.3f}", va="center", ha="left", fontsize=6)

    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    # buffer_rgba() is an (H, W, 4) view into the renderer's memory; copy the
    # RGB channels before saving so later rendering cannot alter them.
    arr = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
    canvas.print_figure(str(output_path), dpi=110)
    return _pil_to_tensor(Image.fromarray(arr))
class SelvaLoraEvaluator:
    """Evaluates a batch of LoRA adapters on every clip of a prepared dataset.

    For each adapter in the JSON spec, one audio sample is generated per
    dataset clip and spectral metrics are computed for each sample. The
    per-adapter averages are compared in a bar chart against the averaged
    metrics of the dataset's reference audio. Use this after a sweep to compare
    candidates before running the next round of training.
    """

    OUTPUT_NODE = True
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "run"
    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("summary_path", "comparison_image")
    OUTPUT_TOOLTIPS = (
        "Path to eval_summary.json — contains spectral metrics per adapter.",
        "Bar chart comparing spectral metrics across all evaluated adapters.",
    )
    DESCRIPTION = (
        "Evaluates multiple LoRA adapters by generating one audio sample per adapter "
        "for every clip in the dataset, then averages spectral metrics for comparison "
        "with the reference audio. "
        "Input is a JSON file listing adapter paths. Empty path = baseline (no LoRA)."
    )

    # Names inside output_dir that the evaluator writes itself. Adapter sample
    # folders must not reuse them: an adapter with id "reference" would
    # otherwise overwrite the copied reference clips.
    _RESERVED_NAMES = ("reference", "eval_summary.json", "metric_comparison.png")

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "eval_file": ("STRING", {
                    "default": "eval_batch.json",
                    "tooltip": (
                        "Path to the JSON evaluation spec. Relative paths are looked "
                        "up in the ComfyUI models directory first, then the output "
                        "directory. "
                        "Each adapter entry needs an 'id' and an optional 'path'. "
                        "Omit 'path' for a no-LoRA baseline."
                    ),
                }),
            }
        }

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _locate_eval_file(eval_file: str) -> Path:
        """Resolve the eval spec path: absolute as-is, else models dir, then output dir."""
        eval_path = Path(eval_file.strip())
        if not eval_path.is_absolute():
            candidate = Path(folder_paths.models_dir) / eval_path
            # is_file(), not exists(): an empty/directory name must not match.
            if not candidate.is_file():
                candidate = Path(folder_paths.get_output_directory()) / eval_path
            eval_path = candidate
        if not eval_path.is_file():
            raise FileNotFoundError(f"[LoRA Evaluator] Eval file not found: {eval_path}")
        return eval_path

    @staticmethod
    def _load_spec(eval_path: Path) -> dict:
        """Parse and validate the JSON spec; raises ValueError when malformed."""
        spec = json.loads(eval_path.read_text(encoding="utf-8"))
        if not isinstance(spec, dict):
            raise ValueError("[LoRA Evaluator] Eval spec must be a JSON object.")
        adapters = spec.get("adapters")
        if not isinstance(adapters, list) or not adapters:
            raise ValueError("[LoRA Evaluator] 'adapters' list is missing or empty.")
        for i, a in enumerate(adapters):
            if not isinstance(a, dict) or "id" not in a:
                raise ValueError(f"[LoRA Evaluator] Adapter at index {i} missing 'id'.")
        if "data_dir" not in spec:
            raise ValueError("[LoRA Evaluator] 'data_dir' is required.")
        if "output_dir" not in spec:
            raise ValueError("[LoRA Evaluator] 'output_dir' is required.")
        return spec

    @staticmethod
    def _collect_reference(npz_files, ref_dir: Path, seq_cfg) -> list:
        """Copy each clip's source audio into ref_dir and compute its metrics.

        Clips without matching audio are skipped; per-clip failures are logged
        and skipped. Returns a list of {clip, wav_path, spectral_metrics}.
        """
        import shutil

        ref_clips = []
        for npz_path in npz_files:
            audio_path = _find_audio(npz_path)
            if audio_path is None:
                continue
            try:
                ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
                ref_wav = ref_wav.unsqueeze(0)  # [1, L]
                ref_out = ref_dir / f"{npz_path.stem}{audio_path.suffix}"
                shutil.copy2(str(audio_path), str(ref_out))
                metrics = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
                ref_clips.append({
                    "clip": npz_path.stem,
                    "wav_path": str(ref_out),
                    "spectral_metrics": metrics,
                })
            except Exception as e:
                print(f"[LoRA Evaluator] Reference {npz_path.name} failed: {e}", flush=True)
        return ref_clips

    @staticmethod
    def _claim_dir_name(adapter_id, used: set) -> str:
        """Return a unique, filesystem-safe folder name for adapter_id.

        Names are compared case-insensitively against ``used``, which is
        updated in place. Duplicate ids, or ids colliding with the evaluator's
        own outputs, get a numeric suffix instead of overwriting another
        folder's WAVs.
        """
        base = _safe_stem(str(adapter_id))
        if not base.strip("."):
            # "", "." and ".." would point at output_dir itself or its parent.
            base = "_" * max(len(base), 1)
        candidate = base
        n = 2
        while candidate.lower() in used:
            candidate = f"{base}_{n}"
            n += 1
        used.add(candidate.lower())
        return candidate

    @staticmethod
    def _apply_adapter(generator, adapter_path: str, record: dict) -> None:
        """Apply and load the LoRA checkpoint at adapter_path into generator in place.

        Relative paths resolve under ``folder_paths.base_path``. The checkpoint
        is either ``{"state_dict": ..., "meta": {...}}`` or a bare state dict.
        Missing meta values default to rank 16, alpha = rank,
        target ["attn.qkv"] and dropout 0.0. The values used are stored in
        ``record["meta"]``. Loading inside this helper lets the CPU copy of
        the checkpoint be freed on return.
        """
        pt_path = Path(adapter_path)
        if not pt_path.is_absolute():
            pt_path = Path(folder_paths.base_path) / pt_path
        if not pt_path.exists():
            raise FileNotFoundError(f"Adapter not found: {pt_path}")
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so a
        # malicious .pt file can execute code. Only evaluate trusted adapters.
        ckpt = torch.load(str(pt_path), map_location="cpu", weights_only=False)
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            state_dict = ckpt["state_dict"]
            meta = ckpt.get("meta") or {}
        else:
            state_dict = ckpt
            meta = {}
        rank = int(meta.get("rank", 16))
        alpha = float(meta.get("alpha", float(rank)))
        target = meta.get("target", ["attn.qkv"])
        if isinstance(target, str):
            # list("attn.qkv") would split the suffix into characters.
            target = [target]
        target = list(target)
        dropout = float(meta.get("lora_dropout", 0.0))
        record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
        n = apply_lora(generator, rank=rank, alpha=alpha,
                       target_suffixes=tuple(target), dropout=dropout)
        if n == 0:
            raise RuntimeError(
                f"apply_lora matched 0 layers (target={target})"
            )
        load_lora(generator, state_dict)
        print(f"[LoRA Evaluator] Loaded {pt_path.name} "
              f"(rank={rank}, {n} layers)", flush=True)

    @staticmethod
    def _save_wav(wav_path: Path, wav, sr: int) -> None:
        """Write wav (a [C, L] or [L] tensor — TODO confirm shape) to wav_path."""
        try:
            torchaudio.save(str(wav_path), wav, sr)
        except Exception:
            # torchaudio's save backend may be missing or fail (it can raise
            # RuntimeError or ImportError depending on the version); fall back
            # to soundfile, which expects (frames, channels).
            import soundfile as sf
            data = wav.detach().cpu().numpy()
            sf.write(str(wav_path), data.T, sr)

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------
    def run(self, model, eval_file):
        """Evaluate every adapter in the spec and write eval_summary.json.

        Returns:
            (summary_path, comparison_image): the summary JSON path and an IMAGE
            tensor with the metric comparison chart. The image is blank when no
            adapter completed or the chart could not be drawn.

        Raises:
            FileNotFoundError: the eval spec cannot be found.
            ValueError: the spec is malformed.

        Per-adapter failures are recorded in the summary, not raised.
        """
        # ------------------------------------------------------------------
        # 1. Resolve and parse the JSON file
        # ------------------------------------------------------------------
        eval_path = self._locate_eval_file(eval_file)
        spec = self._load_spec(eval_path)

        name = spec.get("name", eval_path.stem)
        data_dir = _resolve_path(spec["data_dir"])
        output_dir = _resolve_path(spec["output_dir"])
        steps = int(spec.get("steps", 25))
        seed = int(spec.get("seed", 42))
        output_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n[LoRA Evaluator] '{name}': {len(spec['adapters'])} adapter(s)", flush=True)
        print(f"[LoRA Evaluator] data_dir = {data_dir}", flush=True)
        print(f"[LoRA Evaluator] output_dir = {output_dir}\n", flush=True)

        # ------------------------------------------------------------------
        # 2. Prepare dataset (VAE encode once)
        # ------------------------------------------------------------------
        device = get_device()
        dtype = model["dtype"]
        dataset = _prepare_dataset(model, data_dir, device)
        feature_utils_orig = model["feature_utils"]
        seq_cfg = model["seq_cfg"]

        # ------------------------------------------------------------------
        # 3. Collect reference metrics for all dataset clips
        # ------------------------------------------------------------------
        npz_files = sorted(data_dir.glob("*.npz"))
        ref_dir = output_dir / "reference"
        ref_dir.mkdir(exist_ok=True)
        print(f"[LoRA Evaluator] Computing reference metrics for {len(npz_files)} clip(s)...",
              flush=True)
        ref_clips = self._collect_reference(npz_files, ref_dir, seq_cfg)

        # Average reference metrics across all clips that had audio. Adapter
        # averages below cover the clips each adapter produced, so the two
        # sets can differ if some clips fail.
        ref_avg = _avg_metrics([c["spectral_metrics"] for c in ref_clips])
        print(f"[LoRA Evaluator] Reference avg — "
              f"centroid={ref_avg.get('spectral_centroid_hz', 0):.0f}Hz "
              f"hf={ref_avg.get('hf_energy_ratio', 0):.3f} "
              f"flatness={ref_avg.get('spectral_flatness', 0):.4f}", flush=True)

        # ------------------------------------------------------------------
        # 4. Build summary skeleton
        # ------------------------------------------------------------------
        n_clips = len(dataset)
        summary = {
            "name": name,
            "started_at": datetime.now(timezone.utc).isoformat(),
            "completed_at": None,
            "data_dir": str(data_dir),
            "output_dir": str(output_dir),
            "n_clips": len(ref_clips),   # reference clips with audio
            "n_eval_clips": n_clips,     # dataset clips each adapter is run on
            "steps": steps,
            "seed": seed,
            "reference_avg": ref_avg,
            "reference_clips": ref_clips,
            "adapters": [],
        }
        summary_path = output_dir / "eval_summary.json"

        def _write_summary():
            summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

        _write_summary()

        # ------------------------------------------------------------------
        # 5. Per-adapter evaluation loop (all clips)
        # ------------------------------------------------------------------
        # Clip names come from the sorted *.npz listing. This assumes
        # _prepare_dataset keeps that order (not verifiable from here).
        if n_clips != len(npz_files):
            print(f"[LoRA Evaluator] Warning: dataset has {n_clips} clip(s) but "
                  f"{len(npz_files)} .npz file(s) were found; clip names may not "
                  f"match generated samples.", flush=True)

        pbar = comfy.utils.ProgressBar(len(spec["adapters"]) * n_clips)
        used_dirs = {n.lower() for n in self._RESERVED_NAMES}

        for adapter_spec in spec["adapters"]:
            adapter_id = adapter_spec["id"]
            adapter_path = str(adapter_spec.get("path") or "").strip()

            record = {
                "id": adapter_id,
                "path": adapter_path or None,
                "meta": None,
                "clips": [],
                "avg_metrics": None,
                "status": "running",
            }
            print(f"\n[LoRA Evaluator] ── '{adapter_id}' ({n_clips} clips) ──", flush=True)

            generator = None
            clips_done = 0  # clips whose progress tick has been reported
            try:
                clip_dir = output_dir / self._claim_dir_name(adapter_id, used_dirs)
                clip_dir.mkdir(exist_ok=True)

                with torch.inference_mode(False):
                    generator = copy.deepcopy(model["generator"])
                    if adapter_path:
                        self._apply_adapter(generator, adapter_path, record)
                    else:
                        print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)

                    generator = generator.to(device, dtype)
                    generator.update_seq_lengths(
                        latent_seq_len=seq_cfg.latent_seq_len,
                        clip_seq_len=seq_cfg.clip_seq_len,
                        sync_seq_len=seq_cfg.sync_seq_len,
                    )

                    clip_metrics_list = []
                    for clip_idx in range(n_clips):
                        clip_stem = (npz_files[clip_idx].stem
                                     if clip_idx < len(npz_files)
                                     else f"clip_{clip_idx:03d}")
                        wav, sr = _eval_sample(
                            generator, feature_utils_orig, dataset,
                            seq_cfg, device, dtype,
                            num_steps=steps, seed=seed, clip_idx=clip_idx,
                        )
                        if wav is not None:
                            wav_path = clip_dir / f"{clip_stem}.wav"
                            self._save_wav(wav_path, wav, sr)
                            metrics = _spectral_metrics(wav, sr)
                            clip_metrics_list.append(metrics)
                            record["clips"].append({
                                "clip": clip_stem,
                                "wav_path": str(wav_path),
                                "spectral_metrics": metrics,
                            })
                            print(f" [{clip_idx+1}/{n_clips}] {clip_stem} "
                                  f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
                                  f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
                        clips_done += 1
                        pbar.update(1)

                    if not clip_metrics_list:
                        # Averages over nothing would chart as all-zero bars.
                        raise RuntimeError("no eval samples were generated")

                    record["avg_metrics"] = _avg_metrics(clip_metrics_list)
                    record["status"] = "completed"
                    avg = record["avg_metrics"]
                    print(f"[LoRA Evaluator] '{adapter_id}' avg — "
                          f"centroid={avg.get('spectral_centroid_hz', 0):.0f}Hz "
                          f"hf={avg.get('hf_energy_ratio', 0):.3f} "
                          f"flatness={avg.get('spectral_flatness', 0):.4f}", flush=True)

            except Exception as e:
                record["status"] = "failed"
                record["error"] = str(e)
                print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
                traceback.print_exc()
                pbar.update(n_clips - clips_done)
            finally:
                # Drop the per-adapter model copy before freeing device memory.
                generator = None
                soft_empty_cache()

            summary["adapters"].append(record)
            _write_summary()

        # ------------------------------------------------------------------
        # 6. Finalise summary
        # ------------------------------------------------------------------
        summary["completed_at"] = datetime.now(timezone.utc).isoformat()
        _write_summary()
        print(f"\n[LoRA Evaluator] Done. Summary: {summary_path}", flush=True)

        # ------------------------------------------------------------------
        # 7. Comparison chart
        # ------------------------------------------------------------------
        comparison = None
        completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
        if completed:
            ids = ["reference"] + [str(r["id"]) for r in completed]
            metrics_list = [summary["reference_avg"]] + [r["avg_metrics"] for r in completed]
            chart_path = output_dir / "metric_comparison.png"
            try:
                comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
                print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
            except Exception as e:
                # The summary is already on disk; don't lose it over a plot.
                print(f"[LoRA Evaluator] Comparison chart failed: {e}", flush=True)
                traceback.print_exc()

        if comparison is None:
            from PIL import Image
            comparison = _pil_to_tensor(Image.new("RGB", (400, 200), (255, 255, 255)))

        return (str(summary_path), comparison)