feat: add SelVA LoRA Evaluator node
Generates audio samples from a list of adapters against a fixed reference
clip, collects spectral metrics for each, and outputs a comparison bar
chart + eval_summary.json. Useful for comparing sweep candidates before
committing to a next round of training.
JSON format: name, data_dir, output_dir, steps, seed, adapters[{id, path}].
Empty path = baseline (no LoRA).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,7 @@ _NODES = {
|
||||
"SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"),
|
||||
"SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"),
|
||||
"SelvaSkipExperiment": (".selva_skip_experiment", "SelvaSkipExperiment", "SelVA Skip Experiment"),
|
||||
"SelvaLoraEvaluator": (".selva_lora_evaluator", "SelvaLoraEvaluator", "SelVA LoRA Evaluator"),
|
||||
}
|
||||
|
||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||
|
||||
@@ -0,0 +1,365 @@
|
||||
"""SelVA LoRA Evaluator — generates eval samples from multiple adapters for comparison.
|
||||
|
||||
JSON format:
|
||||
{
|
||||
"name": "eval_batch_1",
|
||||
"data_dir": "/path/to/features",
|
||||
"output_dir": "/path/to/evals/batch1",
|
||||
"steps": 25,
|
||||
"seed": 42,
|
||||
"adapters": [
|
||||
{"id": "baseline"},
|
||||
{"id": "lr_3e4_10k", "path": "/path/to/adapter_final.pt"},
|
||||
{"id": "lr_5e4_10k", "path": "/path/to/adapter_final.pt"}
|
||||
]
|
||||
}
|
||||
|
||||
Empty / missing "path" = baseline (no LoRA applied).
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import sys
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
|
||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||
from .selva_lora_trainer import (
|
||||
_prepare_dataset,
|
||||
_eval_sample,
|
||||
_spectral_metrics,
|
||||
_save_spectrogram,
|
||||
_pil_to_tensor,
|
||||
)
|
||||
from selva_core.model.lora import apply_lora, load_lora
|
||||
|
||||
|
||||
def _resolve_path(raw: str) -> Path:
|
||||
p = Path(raw.strip())
|
||||
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
|
||||
if not p.is_absolute() or unix_style_on_windows:
|
||||
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
|
||||
return p
|
||||
|
||||
|
||||
def _safe_stem(adapter_id: str) -> str:
|
||||
"""Replace characters illegal in filenames."""
|
||||
for ch in r'/\:*?"<>|':
|
||||
adapter_id = adapter_id.replace(ch, "_")
|
||||
return adapter_id
|
||||
|
||||
|
||||
def _draw_metric_comparison(adapter_ids: list, metrics_list: list, output_path: Path):
|
||||
"""Draw a 2×2 grid of horizontal bar charts comparing spectral metrics.
|
||||
|
||||
Saves a PNG to output_path and returns a ComfyUI IMAGE tensor.
|
||||
"""
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
METRICS = [
|
||||
("hf_energy_ratio", "HF Energy Ratio (>4 kHz)"),
|
||||
("spectral_centroid_hz", "Spectral Centroid (Hz)"),
|
||||
("spectral_flatness", "Spectral Flatness"),
|
||||
("temporal_variance", "Temporal Variance"),
|
||||
]
|
||||
COLORS = [
|
||||
"#4285F4", "#EA4335", "#34A853", "#FBBC05",
|
||||
"#9B59B6", "#1ABC9C", "#E67E22", "#95A5A6",
|
||||
]
|
||||
|
||||
n = len(adapter_ids)
|
||||
fig = Figure(figsize=(12, max(4, n * 0.6 + 2)), dpi=110, tight_layout=True)
|
||||
axes = [fig.add_subplot(2, 2, i + 1) for i in range(4)]
|
||||
|
||||
for ax, (key, title) in zip(axes, METRICS):
|
||||
values = []
|
||||
colors = []
|
||||
for i, m in enumerate(metrics_list):
|
||||
v = m.get(key, 0.0) if m else 0.0
|
||||
values.append(v)
|
||||
colors.append(COLORS[i % len(COLORS)])
|
||||
|
||||
bars = ax.barh(adapter_ids, values, color=colors, height=0.6)
|
||||
ax.set_title(title, fontsize=9)
|
||||
ax.set_xlabel(key, fontsize=8)
|
||||
ax.tick_params(axis="y", labelsize=7)
|
||||
ax.tick_params(axis="x", labelsize=7)
|
||||
|
||||
# Value labels on bars
|
||||
for bar, val in zip(bars, values):
|
||||
w = bar.get_width()
|
||||
ax.text(w * 1.01, bar.get_y() + bar.get_height() / 2,
|
||||
f"{val:.3f}", va="center", ha="left", fontsize=6)
|
||||
|
||||
canvas = FigureCanvasAgg(fig)
|
||||
canvas.draw()
|
||||
canvas.print_figure(str(output_path), dpi=110)
|
||||
|
||||
buf = canvas.buffer_rgba()
|
||||
w, h = canvas.get_width_height()
|
||||
arr = np.frombuffer(buf, dtype=np.uint8).reshape(h, w, 4)[:, :, :3]
|
||||
from PIL import Image
|
||||
return _pil_to_tensor(Image.fromarray(arr))
|
||||
|
||||
|
||||
class SelvaLoraEvaluator:
|
||||
"""Evaluates a batch of LoRA adapters on a fixed reference clip.
|
||||
|
||||
Generates one audio sample per adapter, computes spectral metrics for each,
|
||||
and produces a comparison chart. Use this after a sweep to compare candidates
|
||||
before running the next round of training.
|
||||
"""
|
||||
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
FUNCTION = "run"
|
||||
RETURN_TYPES = ("STRING", "IMAGE")
|
||||
RETURN_NAMES = ("summary_path", "comparison_image")
|
||||
OUTPUT_TOOLTIPS = (
|
||||
"Path to eval_summary.json — contains spectral metrics per adapter.",
|
||||
"Bar chart comparing spectral metrics across all evaluated adapters.",
|
||||
)
|
||||
DESCRIPTION = (
|
||||
"Evaluates multiple LoRA adapters by generating one audio sample per adapter "
|
||||
"from a fixed reference clip, then collects spectral metrics for comparison. "
|
||||
"Input is a JSON file listing adapter paths. Empty path = baseline (no LoRA)."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("SELVA_MODEL",),
|
||||
"eval_file": ("STRING", {
|
||||
"default": "eval_batch.json",
|
||||
"tooltip": (
|
||||
"Path to the JSON evaluation spec. Relative paths resolve "
|
||||
"to the ComfyUI output directory. "
|
||||
"Each adapter entry needs an 'id' and an optional 'path'. "
|
||||
"Omit 'path' for a no-LoRA baseline."
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
def run(self, model, eval_file):
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Resolve and parse the JSON file
|
||||
# ------------------------------------------------------------------
|
||||
eval_path = Path(eval_file.strip())
|
||||
if not eval_path.is_absolute():
|
||||
candidate = Path(folder_paths.models_dir) / eval_path
|
||||
if not candidate.exists():
|
||||
candidate = Path(folder_paths.get_output_directory()) / eval_path
|
||||
eval_path = candidate
|
||||
if not eval_path.exists():
|
||||
raise FileNotFoundError(f"[LoRA Evaluator] Eval file not found: {eval_path}")
|
||||
|
||||
spec = json.loads(eval_path.read_text(encoding="utf-8"))
|
||||
|
||||
if "adapters" not in spec or not spec["adapters"]:
|
||||
raise ValueError("[LoRA Evaluator] 'adapters' list is missing or empty.")
|
||||
for i, a in enumerate(spec["adapters"]):
|
||||
if "id" not in a:
|
||||
raise ValueError(f"[LoRA Evaluator] Adapter at index {i} missing 'id'.")
|
||||
|
||||
if "data_dir" not in spec:
|
||||
raise ValueError("[LoRA Evaluator] 'data_dir' is required.")
|
||||
if "output_dir" not in spec:
|
||||
raise ValueError("[LoRA Evaluator] 'output_dir' is required.")
|
||||
|
||||
name = spec.get("name", eval_path.stem)
|
||||
data_dir = _resolve_path(spec["data_dir"])
|
||||
output_dir = _resolve_path(spec["output_dir"])
|
||||
steps = int(spec.get("steps", 25))
|
||||
seed = int(spec.get("seed", 42))
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n[LoRA Evaluator] '{name}': {len(spec['adapters'])} adapter(s)", flush=True)
|
||||
print(f"[LoRA Evaluator] data_dir = {data_dir}", flush=True)
|
||||
print(f"[LoRA Evaluator] output_dir = {output_dir}\n", flush=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. Prepare dataset (VAE encode once)
|
||||
# ------------------------------------------------------------------
|
||||
device = get_device()
|
||||
dtype = model["dtype"]
|
||||
dataset = _prepare_dataset(model, data_dir, device)
|
||||
|
||||
feature_utils_orig = model["feature_utils"]
|
||||
seq_cfg = model["seq_cfg"]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. Build summary skeleton
|
||||
# ------------------------------------------------------------------
|
||||
summary = {
|
||||
"name": name,
|
||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||
"completed_at": None,
|
||||
"data_dir": str(data_dir),
|
||||
"output_dir": str(output_dir),
|
||||
"steps": steps,
|
||||
"seed": seed,
|
||||
"adapters": [],
|
||||
}
|
||||
summary_path = output_dir / "eval_summary.json"
|
||||
|
||||
def _write_summary():
|
||||
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
||||
|
||||
_write_summary()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. Per-adapter evaluation loop
|
||||
# ------------------------------------------------------------------
|
||||
pbar = comfy.utils.ProgressBar(len(spec["adapters"]))
|
||||
|
||||
for adapter_spec in spec["adapters"]:
|
||||
adapter_id = adapter_spec["id"]
|
||||
adapter_path = (adapter_spec.get("path") or "").strip()
|
||||
safe_id = _safe_stem(adapter_id)
|
||||
|
||||
record = {
|
||||
"id": adapter_id,
|
||||
"path": adapter_path or None,
|
||||
"meta": None,
|
||||
"wav_path": None,
|
||||
"spectrogram_path": None,
|
||||
"spectral_metrics": None,
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
print(f"[LoRA Evaluator] ── '{adapter_id}' ──", flush=True)
|
||||
|
||||
try:
|
||||
with torch.inference_mode(False):
|
||||
# 4a. Deep-copy generator
|
||||
generator = copy.deepcopy(model["generator"])
|
||||
|
||||
# 4b. Apply + load LoRA if path given
|
||||
if adapter_path:
|
||||
pt_path = Path(adapter_path)
|
||||
if not pt_path.is_absolute():
|
||||
pt_path = Path(folder_paths.base_path) / pt_path
|
||||
if not pt_path.exists():
|
||||
raise FileNotFoundError(f"Adapter not found: {pt_path}")
|
||||
|
||||
ckpt = torch.load(str(pt_path), map_location="cpu",
|
||||
weights_only=False)
|
||||
if isinstance(ckpt, dict) and "state_dict" in ckpt:
|
||||
state_dict = ckpt["state_dict"]
|
||||
meta = ckpt.get("meta", {})
|
||||
else:
|
||||
state_dict = ckpt
|
||||
meta = {}
|
||||
|
||||
rank = int(meta.get("rank", 16))
|
||||
alpha = float(meta.get("alpha", float(rank)))
|
||||
target = list(meta.get("target", ["attn.qkv"]))
|
||||
dropout = float(meta.get("lora_dropout", 0.0))
|
||||
record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
|
||||
|
||||
n = apply_lora(generator, rank=rank, alpha=alpha,
|
||||
target_suffixes=tuple(target), dropout=dropout)
|
||||
if n == 0:
|
||||
raise RuntimeError(
|
||||
f"apply_lora matched 0 layers (target={target})"
|
||||
)
|
||||
load_lora(generator, state_dict)
|
||||
print(f"[LoRA Evaluator] Loaded {pt_path.name} "
|
||||
f"(rank={rank}, {n} layers)", flush=True)
|
||||
else:
|
||||
print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)
|
||||
|
||||
# 4c. Move to device and set sequence lengths
|
||||
generator = generator.to(device, dtype)
|
||||
generator.update_seq_lengths(
|
||||
latent_seq_len=seq_cfg.latent_seq_len,
|
||||
clip_seq_len=seq_cfg.clip_seq_len,
|
||||
sync_seq_len=seq_cfg.sync_seq_len,
|
||||
)
|
||||
|
||||
# 4d. Run inference
|
||||
wav, sr = _eval_sample(
|
||||
generator, feature_utils_orig, dataset,
|
||||
seq_cfg, device, dtype,
|
||||
num_steps=steps, seed=seed,
|
||||
)
|
||||
|
||||
if wav is None:
|
||||
raise RuntimeError("_eval_sample returned None")
|
||||
|
||||
# 4e. Save wav
|
||||
wav_path = output_dir / f"{safe_id}.wav"
|
||||
try:
|
||||
torchaudio.save(str(wav_path), wav, sr)
|
||||
except RuntimeError:
|
||||
import soundfile as sf
|
||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||
record["wav_path"] = str(wav_path)
|
||||
print(f"[LoRA Evaluator] Saved {wav_path.name}", flush=True)
|
||||
|
||||
# 4f. Spectral metrics
|
||||
metrics = _spectral_metrics(wav, sr)
|
||||
record["spectral_metrics"] = metrics
|
||||
print(f"[LoRA Evaluator] hf={metrics['hf_energy_ratio']:.3f} "
|
||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
||||
f"flatness={metrics['spectral_flatness']:.3f} "
|
||||
f"tv={metrics['temporal_variance']:.3f}", flush=True)
|
||||
|
||||
# 4g. Spectrogram PNG
|
||||
spec_path = output_dir / safe_id
|
||||
_save_spectrogram(wav, sr, spec_path)
|
||||
record["spectrogram_path"] = str(spec_path.with_suffix(".png"))
|
||||
|
||||
record["status"] = "completed"
|
||||
|
||||
except Exception as e:
|
||||
record["status"] = "failed"
|
||||
record["error"] = str(e)
|
||||
print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
# Free generator copy immediately — large model, many adapters
|
||||
try:
|
||||
del generator
|
||||
except NameError:
|
||||
pass
|
||||
soft_empty_cache()
|
||||
|
||||
summary["adapters"].append(record)
|
||||
_write_summary()
|
||||
pbar.update(1)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. Finalise summary
|
||||
# ------------------------------------------------------------------
|
||||
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
|
||||
_write_summary()
|
||||
print(f"\n[LoRA Evaluator] Done. Summary: {summary_path}", flush=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. Comparison chart
|
||||
# ------------------------------------------------------------------
|
||||
completed = [r for r in summary["adapters"] if r["status"] == "completed"]
|
||||
if completed:
|
||||
ids = [r["id"] for r in completed]
|
||||
metrics_list = [r["spectral_metrics"] for r in completed]
|
||||
chart_path = output_dir / "metric_comparison.png"
|
||||
comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
|
||||
print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
|
||||
else:
|
||||
from PIL import Image
|
||||
comparison = _pil_to_tensor(Image.new("RGB", (400, 200), (255, 255, 255)))
|
||||
|
||||
return (str(summary_path), comparison)
|
||||
Reference in New Issue
Block a user