d2e1ea7b80
Generates audio samples from a list of adapters against a fixed reference
clip, collects spectral metrics for each, and outputs a comparison bar
chart + eval_summary.json. Useful for comparing sweep candidates before
committing to a next round of training.
JSON format: name, data_dir, output_dir, steps, seed, adapters[{id, path}].
Empty path = baseline (no LoRA).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
366 lines
14 KiB
Python
366 lines
14 KiB
Python
"""SelVA LoRA Evaluator — generates eval samples from multiple adapters for comparison.
|
||
|
||
JSON format:
|
||
{
|
||
"name": "eval_batch_1",
|
||
"data_dir": "/path/to/features",
|
||
"output_dir": "/path/to/evals/batch1",
|
||
"steps": 25,
|
||
"seed": 42,
|
||
"adapters": [
|
||
{"id": "baseline"},
|
||
{"id": "lr_3e4_10k", "path": "/path/to/adapter_final.pt"},
|
||
{"id": "lr_5e4_10k", "path": "/path/to/adapter_final.pt"}
|
||
]
|
||
}
|
||
|
||
Empty / missing "path" = baseline (no LoRA applied).
|
||
"""
|
||
|
||
import copy
|
||
import json
|
||
import sys
|
||
import traceback
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torchaudio
|
||
|
||
import comfy.utils
|
||
import folder_paths
|
||
|
||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||
from .selva_lora_trainer import (
|
||
_prepare_dataset,
|
||
_eval_sample,
|
||
_spectral_metrics,
|
||
_save_spectrogram,
|
||
_pil_to_tensor,
|
||
)
|
||
from selva_core.model.lora import apply_lora, load_lora
|
||
|
||
|
||
def _resolve_path(raw: str) -> Path:
|
||
p = Path(raw.strip())
|
||
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
|
||
if not p.is_absolute() or unix_style_on_windows:
|
||
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
|
||
return p
|
||
|
||
|
||
def _safe_stem(adapter_id: str) -> str:
|
||
"""Replace characters illegal in filenames."""
|
||
for ch in r'/\:*?"<>|':
|
||
adapter_id = adapter_id.replace(ch, "_")
|
||
return adapter_id
|
||
|
||
|
||
def _draw_metric_comparison(adapter_ids: list, metrics_list: list, output_path: Path):
    """Draw a 2×2 grid of horizontal bar charts comparing spectral metrics.

    Each panel shows one metric with one bar per adapter.  Bars are placed at
    explicit integer positions and labelled afterwards, so adapters that share
    the same id still get their own row instead of being merged onto a single
    categorical tick.

    Args:
        adapter_ids: Labels for the bars, one per adapter.
        metrics_list: Metric dicts in the same order as ``adapter_ids``.  A
            falsy entry, or an entry missing a metric key, is plotted as 0.0.
        output_path: Where the PNG is written.

    Returns:
        The result of ``_pil_to_tensor`` applied to the rendered RGB image
        (a ComfyUI IMAGE tensor).
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure
    from PIL import Image

    METRICS = [
        ("hf_energy_ratio", "HF Energy Ratio (>4 kHz)"),
        ("spectral_centroid_hz", "Spectral Centroid (Hz)"),
        ("spectral_flatness", "Spectral Flatness"),
        ("temporal_variance", "Temporal Variance"),
    ]
    COLORS = [
        "#4285F4", "#EA4335", "#34A853", "#FBBC05",
        "#9B59B6", "#1ABC9C", "#E67E22", "#95A5A6",
    ]

    n = len(adapter_ids)
    # Numeric y positions: passing the id strings straight to barh() would use
    # matplotlib's categorical axis, which maps equal strings to the same row.
    positions = list(range(n))
    labels = [str(adapter_id) for adapter_id in adapter_ids]
    # The palette repeats once there are more adapters than colours.
    bar_colors = [COLORS[i % len(COLORS)] for i in range(len(metrics_list))]

    # Figure height grows with the number of adapters so labels stay legible.
    fig = Figure(figsize=(12, max(4, n * 0.6 + 2)), dpi=110, tight_layout=True)
    axes = [fig.add_subplot(2, 2, i + 1) for i in range(4)]

    for ax, (key, title) in zip(axes, METRICS):
        values = [m.get(key, 0.0) if m else 0.0 for m in metrics_list]

        bars = ax.barh(positions, values, color=bar_colors, height=0.6)
        ax.set_yticks(positions)
        ax.set_yticklabels(labels)
        ax.set_title(title, fontsize=9)
        ax.set_xlabel(key, fontsize=8)
        ax.tick_params(axis="y", labelsize=7)
        ax.tick_params(axis="x", labelsize=7)

        # Value labels just past the end of each bar
        for bar, val in zip(bars, values):
            w = bar.get_width()
            ax.text(w * 1.01, bar.get_y() + bar.get_height() / 2,
                    f"{val:.3f}", va="center", ha="left", fontsize=6)

    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    canvas.print_figure(str(output_path), dpi=110)

    # Read the rendered RGBA buffer back and drop the alpha channel.
    buf = canvas.buffer_rgba()
    w, h = canvas.get_width_height()
    arr = np.frombuffer(buf, dtype=np.uint8).reshape(h, w, 4)[:, :, :3]
    return _pil_to_tensor(Image.fromarray(arr))
|
||
|
||
|
||
class SelvaLoraEvaluator:
    """Evaluates a batch of LoRA adapters on a fixed reference clip.

    Generates one audio sample per adapter, computes spectral metrics for each,
    and produces a comparison chart. Use this after a sweep to compare candidates
    before running the next round of training.

    A failure while evaluating one adapter is recorded in the summary
    (``status: "failed"`` plus an ``error`` message) and does not stop the
    remaining adapters from being evaluated.
    """

    OUTPUT_NODE = True
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "run"
    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("summary_path", "comparison_image")
    OUTPUT_TOOLTIPS = (
        "Path to eval_summary.json — contains spectral metrics per adapter.",
        "Bar chart comparing spectral metrics across all evaluated adapters.",
    )
    DESCRIPTION = (
        "Evaluates multiple LoRA adapters by generating one audio sample per adapter "
        "from a fixed reference clip, then collects spectral metrics for comparison. "
        "Input is a JSON file listing adapter paths. Empty path = baseline (no LoRA)."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the loaded model and the eval spec path."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "eval_file": ("STRING", {
                    "default": "eval_batch.json",
                    "tooltip": (
                        "Path to the JSON evaluation spec. Relative paths are looked up "
                        "in the ComfyUI models directory first, then the output directory. "
                        "Each adapter entry needs an 'id' and an optional 'path'. "
                        "Omit 'path' for a no-LoRA baseline."
                    ),
                }),
            }
        }

    @staticmethod
    def _locate_eval_file(eval_file):
        """Resolve ``eval_file`` to an existing path or raise FileNotFoundError.

        Relative paths are tried under the models directory first and fall
        back to the output directory.
        """
        eval_path = Path(eval_file.strip())
        if not eval_path.is_absolute():
            candidate = Path(folder_paths.models_dir) / eval_path
            if not candidate.exists():
                candidate = Path(folder_paths.get_output_directory()) / eval_path
            eval_path = candidate
        if not eval_path.exists():
            raise FileNotFoundError(f"[LoRA Evaluator] Eval file not found: {eval_path}")
        return eval_path

    @staticmethod
    def _validate_spec(spec):
        """Raise ValueError when the parsed spec is structurally unusable.

        Everything the evaluation loop touches before its per-adapter ``try``
        is checked here, so one malformed entry is reported up front instead
        of aborting a half-finished batch.
        """
        if not isinstance(spec, dict):
            raise ValueError("[LoRA Evaluator] Eval file must contain a JSON object.")

        adapters = spec.get("adapters")
        if not isinstance(adapters, list) or not adapters:
            raise ValueError("[LoRA Evaluator] 'adapters' list is missing or empty.")
        for i, a in enumerate(adapters):
            # A bare string/list entry would otherwise pass a substring or
            # membership test for "id" and blow up later with a TypeError.
            if not isinstance(a, dict) or "id" not in a:
                raise ValueError(f"[LoRA Evaluator] Adapter at index {i} missing 'id'.")
            path = a.get("path")
            # Falsy values (missing, null, "") keep meaning "baseline".
            if path and not isinstance(path, str):
                raise ValueError(
                    f"[LoRA Evaluator] Adapter at index {i}: 'path' must be a string."
                )

        if "data_dir" not in spec:
            raise ValueError("[LoRA Evaluator] 'data_dir' is required.")
        if "output_dir" not in spec:
            raise ValueError("[LoRA Evaluator] 'output_dir' is required.")

    @staticmethod
    def _unique_stem(stem, used):
        """Return ``stem``, or ``stem_2``, ``stem_3``, ... if it was already handed out.

        Prevents two adapters whose ids sanitise to the same filename from
        overwriting each other's outputs.  ``used`` is updated in place.
        """
        unique = stem
        suffix = 2
        while unique in used:
            unique = f"{stem}_{suffix}"
            suffix += 1
        used.add(unique)
        return unique

    @staticmethod
    def _attach_lora(generator, adapter_path, record):
        """Load the adapter checkpoint at ``adapter_path`` into ``generator``.

        Relative adapter paths are resolved against ``folder_paths.base_path``.
        The LoRA hyper-parameters are read from the checkpoint's ``meta`` dict
        (with defaults when absent) and stored on ``record["meta"]``.

        Raises:
            FileNotFoundError: the adapter file does not exist.
            RuntimeError: ``apply_lora`` reported zero matching layers.
        """
        pt_path = Path(adapter_path)
        if not pt_path.is_absolute():
            pt_path = Path(folder_paths.base_path) / pt_path
        if not pt_path.exists():
            raise FileNotFoundError(f"Adapter not found: {pt_path}")

        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # loading an adapter executes whatever the file contains.  Only
        # evaluate adapter files from sources you trust.
        ckpt = torch.load(str(pt_path), map_location="cpu",
                          weights_only=False)
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            state_dict = ckpt["state_dict"]
            # ``or {}`` also covers a checkpoint that stores ``"meta": None``.
            meta = ckpt.get("meta") or {}
        else:
            # Bare checkpoint: the loaded object is used as the state dict.
            state_dict = ckpt
            meta = {}

        rank = int(meta.get("rank", 16))
        alpha = float(meta.get("alpha", float(rank)))
        target = list(meta.get("target", ["attn.qkv"]))
        dropout = float(meta.get("lora_dropout", 0.0))
        record["meta"] = {"rank": rank, "alpha": alpha, "target": target}

        n = apply_lora(generator, rank=rank, alpha=alpha,
                       target_suffixes=tuple(target), dropout=dropout)
        if n == 0:
            raise RuntimeError(
                f"apply_lora matched 0 layers (target={target})"
            )
        load_lora(generator, state_dict)
        print(f"[LoRA Evaluator] Loaded {pt_path.name} "
              f"(rank={rank}, {n} layers)", flush=True)

    def run(self, model, eval_file):
        """Evaluate every adapter listed in ``eval_file`` and summarise the results.

        Args:
            model: Mapping providing at least the ``"generator"``, ``"dtype"``,
                ``"feature_utils"`` and ``"seq_cfg"`` entries read below.
            eval_file: Path to the JSON evaluation spec.

        Returns:
            ``(summary_path, comparison_image)`` — the path of
            ``eval_summary.json`` as a string, and the metric comparison chart
            (a blank white image when no adapter completed).

        Raises:
            FileNotFoundError: the eval file cannot be found.
            ValueError: the spec is missing required fields or is malformed.
        """
        # ------------------------------------------------------------------
        # 1. Resolve and parse the JSON file
        # ------------------------------------------------------------------
        eval_path = self._locate_eval_file(eval_file)
        spec = json.loads(eval_path.read_text(encoding="utf-8"))
        self._validate_spec(spec)
        adapters = spec["adapters"]

        name = spec.get("name", eval_path.stem)
        data_dir = _resolve_path(spec["data_dir"])
        output_dir = _resolve_path(spec["output_dir"])
        steps = int(spec.get("steps", 25))
        seed = int(spec.get("seed", 42))
        output_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n[LoRA Evaluator] '{name}': {len(adapters)} adapter(s)", flush=True)
        print(f"[LoRA Evaluator] data_dir = {data_dir}", flush=True)
        print(f"[LoRA Evaluator] output_dir = {output_dir}\n", flush=True)

        # ------------------------------------------------------------------
        # 2. Prepare dataset (VAE encode once)
        # ------------------------------------------------------------------
        device = get_device()
        dtype = model["dtype"]
        dataset = _prepare_dataset(model, data_dir, device)

        feature_utils_orig = model["feature_utils"]
        seq_cfg = model["seq_cfg"]

        # ------------------------------------------------------------------
        # 3. Build summary skeleton
        # ------------------------------------------------------------------
        summary = {
            "name": name,
            "started_at": datetime.now(timezone.utc).isoformat(),
            "completed_at": None,
            "data_dir": str(data_dir),
            "output_dir": str(output_dir),
            "steps": steps,
            "seed": seed,
            "adapters": [],
        }
        summary_path = output_dir / "eval_summary.json"

        def _write_summary():
            summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

        # Written up front and after every adapter so partial results survive
        # an interrupted run.
        _write_summary()

        # ------------------------------------------------------------------
        # 4. Per-adapter evaluation loop
        # ------------------------------------------------------------------
        pbar = comfy.utils.ProgressBar(len(adapters))
        used_stems = set()

        for adapter_spec in adapters:
            adapter_id = adapter_spec["id"]
            adapter_path = (adapter_spec.get("path") or "").strip()
            # str() keeps a numeric id from crashing the filename sanitiser;
            # the stem is de-duplicated so outputs never overwrite each other.
            safe_id = self._unique_stem(_safe_stem(str(adapter_id)), used_stems)

            record = {
                "id": adapter_id,
                "path": adapter_path or None,
                "meta": None,
                "wav_path": None,
                "spectrogram_path": None,
                "spectral_metrics": None,
                "status": "running",
            }

            print(f"[LoRA Evaluator] ── '{adapter_id}' ──", flush=True)

            generator = None
            try:
                with torch.inference_mode(False):
                    # 4a. Deep-copy generator so the caller's model is never patched
                    generator = copy.deepcopy(model["generator"])

                    # 4b. Apply + load LoRA if path given
                    if adapter_path:
                        self._attach_lora(generator, adapter_path, record)
                    else:
                        print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)

                    # 4c. Move to device and set sequence lengths
                    generator = generator.to(device, dtype)
                    generator.update_seq_lengths(
                        latent_seq_len=seq_cfg.latent_seq_len,
                        clip_seq_len=seq_cfg.clip_seq_len,
                        sync_seq_len=seq_cfg.sync_seq_len,
                    )

                    # 4d. Run inference
                    wav, sr = _eval_sample(
                        generator, feature_utils_orig, dataset,
                        seq_cfg, device, dtype,
                        num_steps=steps, seed=seed,
                    )

                    if wav is None:
                        raise RuntimeError("_eval_sample returned None")

                    # 4e. Save wav (soundfile is the fallback when
                    #     torchaudio.save raises RuntimeError)
                    wav_path = output_dir / f"{safe_id}.wav"
                    try:
                        torchaudio.save(str(wav_path), wav, sr)
                    except RuntimeError:
                        import soundfile as sf
                        sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
                    record["wav_path"] = str(wav_path)
                    print(f"[LoRA Evaluator] Saved {wav_path.name}", flush=True)

                    # 4f. Spectral metrics
                    metrics = _spectral_metrics(wav, sr)
                    record["spectral_metrics"] = metrics
                    print(f"[LoRA Evaluator] hf={metrics['hf_energy_ratio']:.3f} "
                          f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
                          f"flatness={metrics['spectral_flatness']:.3f} "
                          f"tv={metrics['temporal_variance']:.3f}", flush=True)

                    # 4g. Spectrogram PNG
                    spec_path = output_dir / safe_id
                    _save_spectrogram(wav, sr, spec_path)
                    record["spectrogram_path"] = str(spec_path.with_suffix(".png"))

                    record["status"] = "completed"

            except Exception as e:
                # Deliberately broad: one bad adapter must not abort the batch.
                record["status"] = "failed"
                record["error"] = str(e)
                print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
                traceback.print_exc()

            finally:
                # Free generator copy immediately — large model, many adapters
                generator = None
                soft_empty_cache()

            summary["adapters"].append(record)
            _write_summary()
            pbar.update(1)

        # ------------------------------------------------------------------
        # 5. Finalise summary
        # ------------------------------------------------------------------
        summary["completed_at"] = datetime.now(timezone.utc).isoformat()
        _write_summary()
        print(f"\n[LoRA Evaluator] Done. Summary: {summary_path}", flush=True)

        # ------------------------------------------------------------------
        # 6. Comparison chart
        # ------------------------------------------------------------------
        completed = [r for r in summary["adapters"] if r["status"] == "completed"]
        if completed:
            ids = [str(r["id"]) for r in completed]
            metrics_list = [r["spectral_metrics"] for r in completed]
            chart_path = output_dir / "metric_comparison.png"
            comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
            print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
        else:
            # Nothing succeeded: return a blank placeholder so the IMAGE
            # output is still populated.
            from PIL import Image
            comparison = _pil_to_tensor(Image.new("RGB", (400, 200), (255, 255, 255)))

        return (str(summary_path), comparison)
|