Files
ComfyUI-SelVA/nodes/selva_lora_scheduler.py
T
Ethanfel 3ec380a27e feat: add SelVA LoRA Scheduler node for automated experiment sweeps
- Extract _prepare_dataset() from SelvaLoraTrainer.train() as a module-level
  function so the dataset can be encoded once and reused across experiments
- Change _train_inner() return value from tuple to dict (adds loss_history,
  meta, completed; train() unpacks for ComfyUI — no change to node outputs)
- New SelvaLoraScheduler node: reads a JSON sweep file, runs N experiments
  sequentially, writes experiment_summary.json (updated after each run) and
  loss_comparison.png with all smoothed curves overlaid on the same axes
- Register SelvaLoraScheduler in nodes/__init__.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:03:21 +02:00

437 lines
17 KiB
Python

"""SelVA LoRA Scheduler — runs a sweep of training experiments from a JSON file.
Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.
JSON format:
    {
      "name": "tier1_sweep",
      "description": "optional human note",
      "data_dir": "dataset/dog_bark",
      "output_root": "lora_output/tier1_sweep",
      "base": { "rank": 16, "lr": 1e-4, "steps": 2000, ... },
      "experiments": [
        {"id": "baseline", "description": "..."},
        {"id": "lora_plus_16", "lora_plus_ratio": 16.0},
        ...
      ]
    }
"""
import copy
import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
from PIL import Image, ImageDraw
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
SelvaLoraTrainer,
_prepare_dataset,
_smooth_losses,
_pil_to_tensor,
)
# Fallback value for every optional training parameter. `_merge_config` starts
# from this mapping before applying the sweep file's `base` section and the
# per-experiment overrides. Per the original author's note, these values mirror
# the SelvaLoraTrainer INPUT_TYPES defaults.
_PARAM_DEFAULTS = dict(
    alpha=0.0,
    target="attn.qkv",
    batch_size=4,
    warmup_steps=100,
    grad_accum=1,
    save_every=500,
    resume_path="",
    seed=42,
    timestep_mode="uniform",
    logit_normal_sigma=1.0,
    curriculum_switch=0.6,
    lora_dropout=0.0,
    lora_plus_ratio=1.0,
)
# Curve colours for the comparison chart, as (R, G, B) byte triples.
# Experiment i uses entry i modulo the palette length, so colours repeat once
# a sweep has more than eight experiments.
_PALETTE = [
    (0x42, 0x85, 0xF4),  # blue
    (0xEA, 0x43, 0x35),  # red
    (0x34, 0xA8, 0x53),  # green
    (0xFB, 0xBC, 0x05),  # yellow
    (0x9B, 0x59, 0xB6),  # purple
    (0x1A, 0xBC, 0x9C),  # teal
    (0xE6, 0x7E, 0x22),  # orange
    (0x95, 0xA5, 0xA6),  # grey
]
def _resolve_path(raw: str) -> Path:
    """Turn a path string from the sweep file into a concrete ``Path``.

    Surrounding whitespace is stripped first. A relative path is re-rooted
    under the ComfyUI output directory. On Windows, a path that reports as
    absolute yet has no drive component is treated the same way, with its
    anchor removed before re-rooting. Every other absolute path is returned
    unchanged.
    """
    candidate = Path(raw.strip())
    if candidate.is_absolute():
        driveless_on_windows = sys.platform == "win32" and not candidate.drive
        if not driveless_on_windows:
            return candidate
    # Strip the anchor (empty for relative paths) so the join below cannot
    # be overridden by a leading root.
    output_dir = Path(folder_paths.get_output_directory())
    return output_dir / candidate.relative_to(candidate.anchor)
def _merge_config(base: dict, experiment: dict) -> dict:
    """Build the effective config for one experiment.

    Precedence, lowest to highest: ``_PARAM_DEFAULTS``, the sweep file's
    ``base`` section, then the experiment's own keys. The bookkeeping keys
    ``id`` and ``description`` are not copied into the result. Neither input
    is modified.
    """
    merged = {**_PARAM_DEFAULTS}
    merged.update(base)
    for key, value in experiment.items():
        # Identification fields are not training parameters.
        if key in ("id", "description"):
            continue
        merged[key] = value
    return merged
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
                   start_step: int, total_steps: int) -> dict:
    """Sample the loss history at every ``save_every`` boundary.

    ``loss_history[i]`` is taken to be the average loss over steps
    ``[start_step + i*log_interval + 1 … start_step + (i+1)*log_interval]``.
    For each boundary the entry whose window ends at or before that step is
    reported.

    Args:
        loss_history: Logged loss values, one per ``log_interval`` steps.
        log_interval: Number of training steps covered by each entry.
        save_every: Spacing of the reported steps.
        start_step: Step at which ``loss_history`` begins (0 for a fresh run).
        total_steps: Last step that may be reported (inclusive).

    Returns:
        ``{str(step): loss rounded to 6 decimals}`` for every boundary that has
        a matching history entry. Empty when ``save_every`` or ``log_interval``
        is not positive, since there is then nothing meaningful to sample.
    """
    # A zero interval would raise ValueError (range step) or ZeroDivisionError
    # below; a config with checkpointing disabled must not fail the summary.
    if save_every <= 0 or log_interval <= 0:
        return {}
    result = {}
    for target in range(save_every, total_steps + 1, save_every):
        # Index of the last logging window that ends at or before `target`.
        idx = (target - start_step) // log_interval - 1
        if 0 <= idx < len(loss_history):
            result[str(target)] = round(loss_history[idx], 6)
    return result
def _draw_comparison_curves(
    experiments_data: list,  # list of dicts: {id, loss_history, log_interval, start_step}
) -> Image.Image:
    """Draw all smoothed loss curves on shared axes, one color per experiment.

    Each entry of ``experiments_data`` is a dict with ``id`` and
    ``loss_history`` plus optional ``log_interval`` (default 50) and
    ``start_step`` (default 0). Entries with fewer than two loss values are
    skipped. Colors are assigned by position in ``experiments_data``, so a
    skipped entry still uses up its palette slot.

    The x axis is the training step: point ``j`` of a curve sits at step
    ``start_step + (j + 1) * log_interval``, and all curves share one step
    range. A run with fewer steps therefore ends earlier on the chart rather
    than being stretched to the full width.

    Returns:
        A 900x420 RGB image. When nothing can be plotted it only contains the
        text "No data to plot".
    """
    W, H = 900, 420
    pl, pr, pt, pb = 75, 160, 30, 50  # wider right margin for legend
    img = Image.new("RGB", (W, H), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    pw = W - pl - pr
    ph = H - pt - pb

    # Collect all smoothed series together with the step of every point.
    series = []
    for i, ed in enumerate(experiments_data):
        lh = ed.get("loss_history") or []
        if len(lh) < 2:
            continue
        sm = _smooth_losses(lh)
        log_interval = ed.get("log_interval", 50)
        start_step = ed.get("start_step", 0)
        series.append({
            "id": ed["id"],
            "smoothed": sm,
            "steps": [start_step + (j + 1) * log_interval for j in range(len(sm))],
            "color": _PALETTE[i % len(_PALETTE)],
        })
    if not series:
        draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
        return img

    # Shared y range; widen a flat range so the division below is safe.
    all_vals = [v for s in series for v in s["smoothed"]]
    lo, hi = min(all_vals), max(all_vals)
    if hi == lo:
        hi = lo + 1e-6
    rng = hi - lo

    # Shared x range in training steps; guard against a zero-width span.
    all_steps = [st for s in series for st in s["steps"]]
    step_lo, step_hi = min(all_steps), max(all_steps)
    step_span = (step_hi - step_lo) or 1

    # Horizontal grid + y-axis labels
    for i in range(5):
        y = pt + int(i * ph / 4)
        val = hi - i * rng / 4
        draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
        draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))

    # Draw each curve
    for s in series:
        pts = [
            (
                pl + int((st - step_lo) * pw / step_span),
                pt + int((1.0 - (v - lo) / rng) * ph),
            )
            for st, v in zip(s["steps"], s["smoothed"])
        ]
        draw.line(pts, fill=s["color"], width=2)

    # Axes
    draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
    draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
    draw.text((pl + 4, 8), "Loss comparison (smoothed)", fill=(40, 40, 40))

    # Step labels at both ends of the x axis
    draw.text((pl, H - pb + 6), f"step {step_lo}", fill=(100, 100, 100))
    draw.text((W - pr - 60, H - pb + 6), f"step {step_hi}", fill=(100, 100, 100))

    # Legend (right side)
    lx = W - pr + 10
    ly = pt
    for s in series:
        draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
        # str() keeps a numeric experiment id from breaking the slice.
        draw.text((lx + 18, ly), str(s["id"])[:20], fill=(40, 40, 40))
        ly += 20
    return img
class SelvaLoraScheduler:
    """Runs a sweep of LoRA training experiments defined in a JSON file.

    The dataset is prepared once through ``_prepare_dataset`` and shared by
    every experiment. Training itself is delegated to
    ``SelvaLoraTrainer._train_inner`` (per the original author's note, each
    experiment trains on its own deep copy of the generator — that happens
    inside the trainer and is not visible here). Results are written to
    `experiment_summary.json` after every run so partial results are preserved
    if the sweep is interrupted.

    A failing experiment is recorded with status "failed" and the sweep moves
    on to the next one. A ComfyUI interrupt (user cancel) is recorded with
    status "interrupted" and stops the whole sweep.
    """

    OUTPUT_NODE = True
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "run"
    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("summary_path", "comparison_curves")
    OUTPUT_TOOLTIPS = (
        "Path to experiment_summary.json — share this file to compare runs.",
        "All smoothed loss curves overlaid on the same axes.",
    )
    DESCRIPTION = (
        "Runs a series of LoRA training experiments defined in a JSON sweep file. "
        "The dataset is encoded once and reused across all experiments. "
        "Results (loss, config, adapter paths) are collected in experiment_summary.json."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the loaded model and the sweep file path."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "experiments_file": ("STRING", {
                    "default": "experiments.json",
                    "tooltip": (
                        "Path to JSON sweep file. Relative paths are looked up in the "
                        "ComfyUI models directory first, then in the output directory; "
                        "absolute paths are used as-is. "
                        "See LORA_TRAINING.md for the file format."
                    ),
                }),
            }
        }

    @staticmethod
    def _parse_params(cfg):
        """Coerce a merged experiment config into typed training parameters.

        Raises ValueError/TypeError when a value cannot be converted; the
        caller treats that as a failure of the single experiment.
        """
        rank = int(cfg.get("rank", 16))
        alpha = float(cfg.get("alpha", 0.0))
        target = str(cfg.get("target", "attn.qkv"))
        return {
            "steps": int(cfg.get("steps", 2000)),
            "rank": rank,
            "lr": float(cfg.get("lr", 1e-4)),
            # A non-positive alpha means "use the rank as alpha".
            "alpha": alpha if alpha > 0.0 else float(rank),
            # Whitespace-separated module-name suffixes.
            "target_suffixes": tuple(target.strip().split()),
            "batch_size": int(cfg.get("batch_size", 4)),
            "warmup_steps": int(cfg.get("warmup_steps", 100)),
            "grad_accum": int(cfg.get("grad_accum", 1)),
            "save_every": int(cfg.get("save_every", 500)),
            "resume_path": str(cfg.get("resume_path", "")),
            "seed": int(cfg.get("seed", 42)),
            "timestep_mode": str(cfg.get("timestep_mode", "uniform")),
            "logit_normal_sigma": float(cfg.get("logit_normal_sigma", 1.0)),
            "curriculum_switch": float(cfg.get("curriculum_switch", 0.6)),
            "lora_dropout": float(cfg.get("lora_dropout", 0.0)),
            "lora_plus_ratio": float(cfg.get("lora_plus_ratio", 1.0)),
        }

    def run(self, model, experiments_file):
        """Execute every experiment of the sweep file and collect the results.

        Args:
            model: SELVA_MODEL dict; this method reads its "dtype",
                "feature_utils", "seq_cfg", "variant" and "mode" entries.
            experiments_file: Path to the JSON sweep file.

        Returns:
            ``(summary_path, comparison_tensor)``: the path of
            experiment_summary.json as a string and the comparison chart
            converted by ``_pil_to_tensor``.

        Raises:
            FileNotFoundError: The sweep file does not exist.
            ValueError: The sweep file is structurally invalid.
            comfy.model_management.InterruptProcessingException: The user
                cancelled; the summary is written before propagating.
        """
        # Local import so the module's top-level import block stays untouched.
        from comfy.model_management import InterruptProcessingException

        # ------------------------------------------------------------------
        # 1. Read + validate the JSON file
        # ------------------------------------------------------------------
        exp_path = Path(experiments_file.strip())
        if not exp_path.is_absolute():
            # Try relative to ComfyUI models dir first, then output dir
            candidate = Path(folder_paths.models_dir) / exp_path
            if not candidate.exists():
                candidate = Path(folder_paths.get_output_directory()) / exp_path
            exp_path = candidate
        if not exp_path.exists():
            raise FileNotFoundError(
                f"[LoRA Scheduler] Experiment file not found: {exp_path}"
            )
        spec = json.loads(exp_path.read_text(encoding="utf-8"))
        if not isinstance(spec, dict):
            raise ValueError(
                "[LoRA Scheduler] Sweep file must contain a JSON object at the top level."
            )
        if "experiments" not in spec or not spec["experiments"]:
            raise ValueError("[LoRA Scheduler] 'experiments' list is missing or empty.")
        for i, exp in enumerate(spec["experiments"]):
            # A non-object entry cannot carry an 'id' either.
            if not isinstance(exp, dict) or "id" not in exp:
                raise ValueError(
                    f"[LoRA Scheduler] Experiment at index {i} is missing required 'id' field."
                )
        sweep_name = spec.get("name", exp_path.stem)
        description = spec.get("description", "")
        # `"base": null` is treated like an absent base section.
        base_cfg = spec.get("base") or {}

        # ------------------------------------------------------------------
        # 2. Resolve data_dir and output_root
        # ------------------------------------------------------------------
        if "data_dir" not in spec:
            raise ValueError("[LoRA Scheduler] 'data_dir' is required in the sweep file.")
        data_dir = _resolve_path(spec["data_dir"])
        output_root = _resolve_path(spec.get("output_root", f"lora_sweeps/{sweep_name}"))
        output_root.mkdir(parents=True, exist_ok=True)
        device = get_device()
        dtype = model["dtype"]
        print(f"\n[LoRA Scheduler] Sweep '{sweep_name}': "
              f"{len(spec['experiments'])} experiment(s)", flush=True)
        if description:
            print(f"[LoRA Scheduler] {description}", flush=True)
        print(f"[LoRA Scheduler] data_dir    = {data_dir}", flush=True)
        print(f"[LoRA Scheduler] output_root = {output_root}\n", flush=True)

        # ------------------------------------------------------------------
        # 3. Load + encode dataset once
        # ------------------------------------------------------------------
        n_clips = sum(1 for _ in data_dir.glob("*.npz"))
        dataset = _prepare_dataset(model, data_dir, device)

        # ------------------------------------------------------------------
        # 4. Build the summary skeleton (written incrementally)
        # ------------------------------------------------------------------
        summary = {
            "sweep_name": sweep_name,
            "description": description,
            "sweep_file": str(exp_path),
            "started_at": datetime.now(timezone.utc).isoformat(),
            "completed_at": None,
            "data_dir": str(data_dir),
            "n_clips": n_clips,
            "experiments": [],
        }
        summary_path = output_root / "experiment_summary.json"

        def _write_summary():
            summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

        _write_summary()

        # ------------------------------------------------------------------
        # 5. Run each experiment
        # ------------------------------------------------------------------
        trainer = SelvaLoraTrainer()
        pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
        all_curve_data = []  # collected for comparison image
        log_interval = 50  # matches _train_inner (per original note; not verifiable here)
        feature_utils_orig = model["feature_utils"]
        seq_cfg = model["seq_cfg"]
        variant = model["variant"]
        mode = model["mode"]

        for exp in spec["experiments"]:
            exp_id = exp["id"]
            exp_desc = exp.get("description", "")
            # str() so a numeric id in the JSON still yields a valid sub-directory.
            output_dir = output_root / str(exp_id)
            print(f"\n[LoRA Scheduler] ── Experiment '{exp_id}' ──", flush=True)
            if exp_desc:
                print(f"[LoRA Scheduler] {exp_desc}", flush=True)

            exp_record = {
                "id": exp_id,
                "description": exp_desc,
                "config": None,  # filled in once the merged config has been parsed
                "results": {"status": "running"},
                "adapter_path": None,
                "output_dir": str(output_dir),
            }
            summary["experiments"].append(exp_record)

            t_start = time.monotonic()
            try:
                # Parsing happens inside the try so that a malformed value in
                # one experiment fails only that experiment, not the sweep.
                p = self._parse_params(_merge_config(base_cfg, exp))
                steps = p["steps"]
                save_every = p["save_every"]
                exp_record["config"] = {
                    "rank": p["rank"], "alpha": p["alpha"], "lr": p["lr"],
                    "steps": steps,
                    "batch_size": p["batch_size"],
                    "warmup_steps": p["warmup_steps"],
                    "grad_accum": p["grad_accum"], "save_every": save_every,
                    "seed": p["seed"], "target": list(p["target_suffixes"]),
                    "timestep_mode": p["timestep_mode"],
                    "logit_normal_sigma": p["logit_normal_sigma"],
                    "curriculum_switch": p["curriculum_switch"],
                    "lora_dropout": p["lora_dropout"],
                    "lora_plus_ratio": p["lora_plus_ratio"],
                }
                output_dir.mkdir(parents=True, exist_ok=True)
                _write_summary()  # persist the "running" state before training

                # Explicitly leave inference mode and enable autograd for training.
                with torch.inference_mode(False), torch.enable_grad():
                    r = trainer._train_inner(
                        model, dataset, feature_utils_orig, seq_cfg,
                        device, dtype, variant, mode,
                        data_dir, output_dir, steps, p["rank"], p["lr"],
                        p["alpha"], p["target_suffixes"], p["batch_size"],
                        p["warmup_steps"],
                        p["grad_accum"], save_every, p["resume_path"], p["seed"],
                        p["timestep_mode"], p["logit_normal_sigma"],
                        p["curriculum_switch"], p["lora_dropout"],
                        p["lora_plus_ratio"],
                    )
                duration = time.monotonic() - t_start

                loss_history = r["loss_history"]
                smoothed = _smooth_losses(loss_history) if loss_history else []

                # Compute summary metrics
                if smoothed:
                    lowest = min(smoothed)
                    final_loss = round(smoothed[-1], 6)
                    min_loss = round(lowest, 6)
                    min_loss_step = (smoothed.index(lowest) + 1) * log_interval
                else:
                    final_loss = min_loss = min_loss_step = None

                exp_record["results"] = {
                    "status": "completed",
                    "final_loss": final_loss,
                    "min_loss": min_loss,
                    "min_loss_step": min_loss_step,
                    "loss_at_steps": _loss_at_steps(
                        loss_history, log_interval, save_every, 0, steps
                    ),
                    "duration_seconds": round(duration, 1),
                }
                exp_record["adapter_path"] = r["adapter_path"]
                all_curve_data.append({
                    "id": exp_id,
                    "loss_history": loss_history,
                    "log_interval": log_interval,
                    "start_step": 0,
                })
            except InterruptProcessingException:
                # User cancel: it subclasses Exception, so without this clause
                # the generic handler below would swallow it and the sweep
                # would carry on with the next experiment.
                duration = time.monotonic() - t_start
                print(f"[LoRA Scheduler] Experiment '{exp_id}' interrupted — "
                      f"stopping sweep.", flush=True)
                exp_record["results"] = {
                    "status": "interrupted",
                    "duration_seconds": round(duration, 1),
                }
                raise
            except Exception as e:
                # Deliberate catch-all boundary: record the failure and go on
                # with the next experiment rather than aborting the sweep.
                duration = time.monotonic() - t_start
                print(f"[LoRA Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
                traceback.print_exc()
                exp_record["results"] = {
                    "status": "failed",
                    "error": str(e),
                    "duration_seconds": round(duration, 1),
                }
            finally:
                # Persist whatever state this experiment ended in, including
                # the interrupted case that is about to propagate.
                _write_summary()
            pbar_outer.update(1)

        # ------------------------------------------------------------------
        # 6. Finalise summary
        # ------------------------------------------------------------------
        summary["completed_at"] = datetime.now(timezone.utc).isoformat()
        _write_summary()
        print(f"\n[LoRA Scheduler] Sweep complete. Summary: {summary_path}", flush=True)

        # ------------------------------------------------------------------
        # 7. Comparison image
        # ------------------------------------------------------------------
        comparison_img = _draw_comparison_curves(all_curve_data)
        comparison_img.save(str(output_root / "loss_comparison.png"))
        comparison_tensor = _pil_to_tensor(comparison_img)
        return (str(summary_path), comparison_tensor)