feat: add SelVA TI Scheduler for sweep-based textual inversion experiments

- SelvaTiScheduler: runs a JSON-defined sweep of TI training experiments,
  loading the dataset once and reusing it across runs
- Collects per-experiment loss history, final/min loss, stability metric
  (loss_std_last_quarter), and duration — written to experiment_summary.json
  after each completed run so partial sweeps survive interruption
- Resume-aware: skips experiments already marked completed in an existing
  summary file
- Outputs smoothed loss comparison chart (same axes, one curve per experiment)
- SelvaTextualInversionTrainer._train_inner now returns a dict
  {embeddings_path, loss_history} so the scheduler can read results;
  train() extracts just the path for ComfyUI

JSON format: name, description, data_dir, output_root, base config,
experiments list with id + param overrides

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 23:13:04 +02:00
parent bb07bc8169
commit e37bfe1b1c
3 changed files with 474 additions and 2 deletions
+1
View File
@@ -16,6 +16,7 @@ _NODES = {
"SelvaSpectralMatcher": (".selva_audio_preprocessors", "SelvaSpectralMatcher", "SelVA Spectral Matcher"), "SelvaSpectralMatcher": (".selva_audio_preprocessors", "SelvaSpectralMatcher", "SelVA Spectral Matcher"),
"SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"), "SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"),
"SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"), "SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"),
"SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"),
} }
for key, (module_path, class_name, display_name) in _NODES.items(): for key, (module_path, class_name, display_name) in _NODES.items():
+6 -2
View File
@@ -201,13 +201,14 @@ class SelvaTextualInversionTrainer:
# Training must run outside inference_mode so autograd works # Training must run outside inference_mode so autograd works
with torch.inference_mode(False), torch.enable_grad(): with torch.inference_mode(False), torch.enable_grad():
return self._train_inner( r = self._train_inner(
model, dataset, feature_utils_orig, seq_cfg, model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode, device, dtype, mode,
data_dir, out_path, data_dir, out_path,
n_tokens, steps, lr, batch_size, n_tokens, steps, lr, batch_size,
warmup_steps, seed, save_every, init_text, warmup_steps, seed, save_every, init_text,
) )
return (r["embeddings_path"],)
def _train_inner( def _train_inner(
self, model, dataset, feature_utils_orig, seq_cfg, self, model, dataset, feature_utils_orig, seq_cfg,
@@ -368,4 +369,7 @@ class SelvaTextualInversionTrainer:
print(f"\n[TI Trainer] Done. Saved: {out_path}", flush=True) print(f"\n[TI Trainer] Done. Saved: {out_path}", flush=True)
soft_empty_cache() soft_empty_cache()
return (str(out_path),) return {
"embeddings_path": str(out_path),
"loss_history": loss_history,
}
+467
View File
@@ -0,0 +1,467 @@
"""SelVA Textual Inversion Scheduler — sweeps TI training experiments from a JSON file.
Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.
JSON format:
{
"name": "ti_sweep_1",
"description": "optional human note",
"data_dir": "dataset/bj_sounds",
"output_root": "ti_output/sweep_1",
"base": {
"n_tokens": 4,
"lr": 1e-3,
"steps": 3000,
"batch_size": 16,
"warmup_steps": 100,
"seed": 42,
"save_every": 1000
},
"experiments": [
{"id": "baseline", "description": "default 4 tokens"},
{"id": "n8_tokens", "n_tokens": 8},
{"id": "lr_5e4", "lr": 5e-4},
{"id": "warm_init", "init_text": "industrial sound design"},
{"id": "n4_more_steps", "steps": 5000}
]
}
"""
import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
_prepare_dataset,
_smooth_losses,
_pil_to_tensor,
)
from .selva_textual_inversion_trainer import SelvaTextualInversionTrainer
# ---------------------------------------------------------------------------
# Helpers (shared with LoRA scheduler, inlined to keep modules independent)
# ---------------------------------------------------------------------------
def _get_system_info() -> dict:
info: dict = {
"torch_version": torch.__version__,
"cuda_version": torch.version.cuda or "N/A",
"gpu_name": None,
"gpu_vram_gb": None,
}
if torch.cuda.is_available():
try:
info["gpu_name"] = torch.cuda.get_device_name(0)
props = torch.cuda.get_device_properties(0)
info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
except Exception:
pass
return info
_PARAM_DEFAULTS = {
"n_tokens": 4,
"lr": 1e-3,
"steps": 3000,
"batch_size": 16,
"warmup_steps": 100,
"seed": 42,
"save_every": 1000,
"init_text": "",
}
_PALETTE = [
(66, 133, 244),
(234, 67, 53),
(52, 168, 83),
(251, 188, 5),
(155, 89, 182),
(26, 188, 156),
(230, 126, 34),
(149, 165, 166),
]
def _resolve_path(raw: str) -> Path:
p = Path(raw.strip())
unix_style_on_windows = (
sys.platform == "win32" and p.is_absolute() and not p.drive
)
if not p.is_absolute() or unix_style_on_windows:
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
return p
def _merge_config(base: dict, experiment: dict) -> dict:
cfg = dict(_PARAM_DEFAULTS)
cfg.update(base)
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
return cfg
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
total_steps: int) -> dict:
result = {}
for target in range(save_every, total_steps + 1, save_every):
idx = target // log_interval - 1
if 0 <= idx < len(loss_history):
result[str(target)] = round(loss_history[idx], 6)
return result
def _draw_comparison_curves(experiments_data: list) -> "Image.Image":
from PIL import Image, ImageDraw
W, H = 900, 420
pl, pr, pt, pb = 75, 160, 30, 50
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr
ph = H - pt - pb
series = []
for i, ed in enumerate(experiments_data):
lh = ed.get("loss_history") or []
if len(lh) < 2:
continue
sm = _smooth_losses(lh)
series.append({
"id": ed["id"],
"smoothed": sm,
"color": _PALETTE[i % len(_PALETTE)],
})
if not series:
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
return img
all_vals = [v for s in series for v in s["smoothed"]]
lo, hi = min(all_vals), max(all_vals)
if hi == lo:
hi = lo + 1e-6
rng = hi - lo
for i in range(5):
y = pt + int(i * ph / 4)
val = hi - i * rng / 4
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
for s in series:
n = len(s["smoothed"])
pts = []
for j, v in enumerate(s["smoothed"]):
x = pl + int(j * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
pts.append((x, y))
draw.line(pts, fill=s["color"], width=2)
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 8), "TI loss comparison (smoothed)", fill=(40, 40, 40))
lx, ly = W - pr + 10, pt
for s in series:
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
ly += 20
return img
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaTiScheduler:
"""Runs a sweep of Textual Inversion experiments defined in a JSON file.
The dataset is loaded once and reused. Each experiment calls
SelvaTextualInversionTrainer._train_inner() with its own config.
Results are written to experiment_summary.json after every completed run.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — compare runs across sweeps.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of Textual Inversion experiments from a JSON sweep file. "
"The dataset is encoded once and reused. Results (loss, config, embeddings "
"paths) are collected in experiment_summary.json after each run."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "ti_experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
"output directory. See node description for the file format."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate JSON
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[TI Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[TI Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[TI Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[TI Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"ti_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
dtype = model["dtype"]
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
print(f"\n[TI Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[TI Scheduler] {description}", flush=True)
print(f"[TI Scheduler] data_dir = {data_dir}", flush=True)
print(f"[TI Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load dataset once
# ------------------------------------------------------------------
n_clips = len(list(data_dir.glob("*.npz")))
dataset = _prepare_dataset(model, data_dir, device)
# ------------------------------------------------------------------
# 4. Build or restore summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = []
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
all_curve_data.append({
"id": rec["id"],
"loss_history": rec["results"].get("loss_history", []),
})
summary = existing
summary["completed_at"] = None
if completed_ids:
print(f"[TI Scheduler] Resuming — skipping {len(completed_ids)} "
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[TI Scheduler] Could not read existing summary ({e}) — starting fresh",
flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": n_clips,
"experiments": [],
}
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
_write_summary()
# ------------------------------------------------------------------
# 5. Run each experiment
# ------------------------------------------------------------------
trainer = SelvaTextualInversionTrainer()
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
log_interval = 50 # matches _train_inner
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[TI Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
pbar_outer.update(1)
continue
cfg = _merge_config(base_cfg, exp)
n_tokens = int(cfg["n_tokens"])
lr = float(cfg["lr"])
steps = int(cfg["steps"])
batch_size = int(cfg["batch_size"])
warmup = int(cfg["warmup_steps"])
seed = int(cfg["seed"])
save_every = int(cfg["save_every"])
init_text = str(cfg["init_text"])
output_dir = output_root / exp_id
output_dir.mkdir(parents=True, exist_ok=True)
out_path = output_dir / "embeddings.pt"
print(f"\n[TI Scheduler] ── Experiment '{exp_id}' ──", flush=True)
if exp_desc:
print(f"[TI Scheduler] {exp_desc}", flush=True)
print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} "
f"batch_size={batch_size} warmup={warmup} seed={seed}", flush=True)
if init_text:
print(f"[TI Scheduler] init_text='{init_text}'", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"n_tokens": n_tokens,
"lr": lr,
"steps": steps,
"batch_size": batch_size,
"warmup_steps": warmup,
"seed": seed,
"save_every": save_every,
"init_text": init_text,
},
"results": {"status": "running"},
"embeddings_path": None,
"output_dir": str(output_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
with torch.inference_mode(False), torch.enable_grad():
r = trainer._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup, seed, save_every, init_text,
)
duration = time.monotonic() - t_start
loss_history = r["loss_history"]
smoothed = _smooth_losses(loss_history) if loss_history else []
final_loss = round(smoothed[-1], 6) if smoothed else None
min_loss = round(min(smoothed), 6) if smoothed else None
min_idx = smoothed.index(min(smoothed)) if smoothed else None
min_loss_step = (min_idx + 1) * log_interval if min_idx is not None else None
loss_std_last_quarter = None
if loss_history:
quarter = max(1, len(loss_history) // 4)
loss_std_last_quarter = round(float(np.std(loss_history[-quarter:])), 6)
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std_last_quarter,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval, save_every, steps
),
"loss_history": [round(v, 6) for v in loss_history],
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["embeddings_path"] = r["embeddings_path"]
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
})
except Exception as e:
duration = time.monotonic() - t_start
print(f"[TI Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
continue
_write_summary()
pbar_outer.update(1)
# ------------------------------------------------------------------
# 6. Finalise
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[TI Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 7. Comparison image
# ------------------------------------------------------------------
comparison_img = _draw_comparison_curves(all_curve_data)
comparison_img.save(str(output_root / "loss_comparison.png"))
comparison_tensor = _pil_to_tensor(comparison_img)
return (str(summary_path), comparison_tensor)