Files
ComfyUI-SelVA/nodes/selva_ti_scheduler.py
T
Ethanfel 92535deab2 fix(ti-scheduler): save comparison image after each completed experiment
Previously the comparison PNG was only written at the very end of the sweep,
so an interrupted run produced no image at all. Now _save_comparison() is
called right after _write_summary() for every successful experiment, keeping
loss_comparison.png current throughout the sweep.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:39:30 +02:00

480 lines
18 KiB
Python

"""SelVA Textual Inversion Scheduler — sweeps TI training experiments from a JSON file.

Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.

JSON format:

    {
      "name": "ti_sweep_1",
      "description": "optional human note",
      "data_dir": "dataset/bj_sounds",
      "output_root": "ti_output/sweep_1",
      "base": {
        "n_tokens": 4,
        "lr": 1e-3,
        "steps": 3000,
        "batch_size": 16,
        "warmup_steps": 100,
        "seed": 42,
        "save_every": 1000
      },
      "experiments": [
        {"id": "baseline", "description": "default 4 tokens"},
        {"id": "n8_tokens", "n_tokens": 8},
        {"id": "lr_5e4", "lr": 5e-4},
        {"id": "warm_init", "init_text": "industrial sound design"},
        {"id": "n4_more_steps", "steps": 5000}
      ]
    }
"""
import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
_prepare_dataset,
_smooth_losses,
_pil_to_tensor,
)
from .selva_textual_inversion_trainer import SelvaTextualInversionTrainer
# ---------------------------------------------------------------------------
# Helpers (shared with LoRA scheduler, inlined to keep modules independent)
# ---------------------------------------------------------------------------
def _get_system_info() -> dict:
    """Collect a small snapshot of the runtime environment for the sweep summary.

    Returns:
        A dict with the torch version, the CUDA version torch reports
        ("N/A" when it reports none), and the name and total memory of CUDA
        device 0 (bytes / 1e9, rounded to one decimal). Both GPU fields stay
        ``None`` when CUDA is unavailable or the device query raises.
    """
    cuda_version = torch.version.cuda
    if not cuda_version:
        cuda_version = "N/A"

    snapshot: dict = dict(
        torch_version=torch.__version__,
        cuda_version=cuda_version,
        gpu_name=None,
        gpu_vram_gb=None,
    )
    if not torch.cuda.is_available():
        return snapshot

    try:
        snapshot["gpu_name"] = torch.cuda.get_device_name(0)
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        snapshot["gpu_vram_gb"] = round(total_bytes / 1e9, 1)
    except Exception:
        # Deliberately best-effort: this is informational metadata only, so a
        # failing device query must never abort the sweep.
        pass
    return snapshot
# Fallback hyper-parameters for one TI experiment. They form the lowest-priority
# layer of a merged config: the sweep file's `base` section and then the
# individual experiment entry override them.
_PARAM_DEFAULTS = dict(
    n_tokens=4,
    lr=1e-3,
    steps=3000,
    batch_size=16,
    warmup_steps=100,
    seed=42,
    save_every=1000,
    init_text="",
    inject_mode="suffix",
)
# RGB colours for the comparison plot. A curve's colour is picked by its index
# modulo the palette length, so colours repeat after eight experiments.
_PALETTE = [
    (66, 133, 244),   # blue
    (234, 67, 53),    # red
    (52, 168, 83),    # green
    (251, 188, 5),    # yellow
    (155, 89, 182),   # purple
    (26, 188, 156),   # teal
    (230, 126, 34),   # orange
    (149, 165, 166),  # grey
]
def _resolve_path(raw: str) -> Path:
    """Turn a path string from the sweep file into a concrete ``Path``.

    Surrounding whitespace is stripped first. An absolute path is returned
    unchanged. Anything else — including, on Windows, a path that reports as
    absolute but carries no drive — is re-rooted under the ComfyUI output
    directory after its own anchor has been removed.

    Args:
        raw: Path text as written in the sweep file.

    Returns:
        The resolved path.
    """
    candidate = Path(raw.strip())
    if candidate.is_absolute():
        rooted_without_drive = sys.platform == "win32" and not candidate.drive
        if not rooted_without_drive:
            return candidate
    # Strip the anchor ("" for a plain relative path) so the join below cannot
    # be overridden by a leading root.
    tail = candidate.relative_to(candidate.anchor)
    return Path(folder_paths.get_output_directory()) / tail
def _merge_config(base: dict, experiment: dict) -> dict:
    """Build the effective config for one experiment.

    Layers are applied from lowest to highest priority: ``_PARAM_DEFAULTS``,
    then the sweep's ``base`` section, then the experiment's own keys. The
    experiment's "id" and "description" are metadata and are not copied.

    Args:
        base: Shared overrides from the sweep file.
        experiment: One entry of the sweep's "experiments" list.

    Returns:
        A new dict; none of the inputs is modified.
    """
    merged = _PARAM_DEFAULTS.copy()
    merged.update(base)
    for key, value in experiment.items():
        if key in ("id", "description"):
            continue
        merged[key] = value
    return merged
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
                   total_steps: int) -> dict:
    """Sample the logged loss at every ``save_every``-th training step.

    Entry ``k`` of ``loss_history`` is treated as the value logged at step
    ``(k + 1) * log_interval``. For each multiple of ``save_every`` up to and
    including ``total_steps`` the matching entry is looked up; a target that
    is not an exact multiple of ``log_interval`` maps to the last entry logged
    at or before it.

    Args:
        loss_history: Loss values, one per ``log_interval`` steps.
        log_interval: Number of steps between two logged values.
        save_every: Spacing of the sampled steps.
        total_steps: Last step that may be sampled.

    Returns:
        ``{str(step): loss rounded to 6 decimals}`` for every target step that
        has a logged value. Empty when either interval is non-positive.
    """
    # A zero interval used to raise (range() with step 0, or floor division by
    # zero), which made the caller record an already finished training run as
    # failed. Negative intervals never produced any sample, so returning an
    # empty mapping for all non-positive values keeps that behaviour.
    if log_interval <= 0 or save_every <= 0:
        return {}

    result = {}
    for target in range(save_every, total_steps + 1, save_every):
        idx = target // log_interval - 1
        if 0 <= idx < len(loss_history):
            result[str(target)] = round(loss_history[idx], 6)
    return result
def _draw_comparison_curves(experiments_data: list) -> "Image.Image":
    """Render every experiment's smoothed loss curve onto one shared plot.

    Args:
        experiments_data: Dicts carrying an "id" and a "loss_history" list.
            Entries with fewer than two loss values are left out of the plot.

    Returns:
        A 900x420 RGB PIL image with grid lines, y-axis labels, one curve per
        plotted experiment and a legend in the right margin. When nothing is
        plottable the image only carries the text "No data to plot".
    """
    from PIL import Image, ImageDraw

    width, height = 900, 420
    left, right, top, bottom = 75, 160, 30, 50  # margins; the right one holds the legend
    plot_w = width - left - right
    plot_h = height - top - bottom
    x_end = width - right
    y_base = height - bottom
    ink = (40, 40, 40)

    canvas = Image.new("RGB", (width, height), (255, 255, 255))
    pen = ImageDraw.Draw(canvas)

    # The colour is keyed on the position in the input list, so an entry that
    # is skipped still consumes its palette slot.
    curves = []
    for slot, entry in enumerate(experiments_data):
        history = entry.get("loss_history") or []
        if len(history) < 2:
            continue
        smoothed = _smooth_losses(history)
        curves.append((entry["id"], smoothed, _PALETTE[slot % len(_PALETTE)]))

    if not curves:
        pen.text((left + 10, top + 10), "No data to plot", fill=(80, 80, 80))
        return canvas

    # One y-range shared by all curves so they are directly comparable.
    pooled = [value for _label, smoothed, _color in curves for value in smoothed]
    lo = min(pooled)
    hi = max(pooled)
    if hi == lo:
        hi = lo + 1e-6  # avoid a zero-height range for flat curves
    span = hi - lo

    # Five horizontal grid lines, labelled from the maximum down to the minimum.
    for tick in range(5):
        y = top + int(tick * plot_h / 4)
        pen.line([(left, y), (x_end, y)], fill=(220, 220, 220), width=1)
        pen.text((2, y - 7), f"{hi - tick * span / 4:.4f}", fill=(100, 100, 100))

    # Each curve is stretched over the full plot width regardless of its length.
    for _label, smoothed, color in curves:
        x_div = max(len(smoothed) - 1, 1)
        polyline = [
            (
                left + int(k * plot_w / x_div),
                top + int((1.0 - (value - lo) / span) * plot_h),
            )
            for k, value in enumerate(smoothed)
        ]
        pen.line(polyline, fill=color, width=2)

    # Axes and title go on top of the curves.
    pen.line([(left, top), (left, y_base)], fill=ink, width=1)
    pen.line([(left, y_base), (x_end, y_base)], fill=ink, width=1)
    pen.text((left + 4, 8), "TI loss comparison (smoothed)", fill=ink)

    # Legend: one colour swatch plus the id (cut to 20 characters) per curve.
    legend_x = x_end + 10
    for row, (label, _smoothed, color) in enumerate(curves):
        legend_y = top + 20 * row
        pen.rectangle(
            [(legend_x, legend_y + 3), (legend_x + 14, legend_y + 13)], fill=color
        )
        pen.text((legend_x + 18, legend_y), label[:20], fill=ink)
    return canvas
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaTiScheduler:
    """Runs a sweep of Textual Inversion experiments defined in a JSON file.

    The dataset is loaded once and reused. Each experiment calls
    SelvaTextualInversionTrainer._train_inner() with its own config.
    Results are written to experiment_summary.json after every completed run.

    Running again against the same ``output_root`` resumes the sweep:
    experiments recorded as "completed" in the existing summary are skipped,
    every other experiment is run again and replaces its earlier record.
    """

    OUTPUT_NODE = True
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "run"
    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("summary_path", "comparison_curves")
    OUTPUT_TOOLTIPS = (
        "Path to experiment_summary.json — compare runs across sweeps.",
        "All smoothed loss curves overlaid on the same axes.",
    )
    DESCRIPTION = (
        "Runs a series of Textual Inversion experiments from a JSON sweep file. "
        "The dataset is encoded once and reused. Results (loss, config, embeddings "
        "paths) are collected in experiment_summary.json after each run."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the loaded model and the sweep file path."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "experiments_file": ("STRING", {
                    "default": "ti_experiments.json",
                    "tooltip": (
                        "Path to JSON sweep file. Relative paths are looked up in the "
                        "ComfyUI models directory first, then in the output directory. "
                        "See node description for the file format."
                    ),
                }),
            }
        }

    def run(self, model, experiments_file):
        """Execute every experiment of the sweep file and collect the results.

        Args:
            model: SELVA_MODEL dict. This method reads its "dtype", "mode",
                "seq_cfg" and "feature_utils" entries and hands the dict on to
                the dataset loader and the trainer.
            experiments_file: Path to the JSON sweep file. A relative path is
                tried under the ComfyUI models directory first, then under the
                output directory.

        Returns:
            ``(summary_path, comparison_curves)``: the path of
            experiment_summary.json as a string, and the overlay plot
            converted with ``_pil_to_tensor``.

        Raises:
            FileNotFoundError: The sweep file does not exist.
            ValueError: The sweep has no experiments, an experiment lacks an
                "id", two experiments share an "id", or "data_dir" is missing.
        """
        # ------------------------------------------------------------------
        # 1. Read + validate JSON
        # ------------------------------------------------------------------
        exp_path = Path(experiments_file.strip())
        if not exp_path.is_absolute():
            candidate = Path(folder_paths.models_dir) / exp_path
            if not candidate.exists():
                candidate = Path(folder_paths.get_output_directory()) / exp_path
            exp_path = candidate
        if not exp_path.exists():
            raise FileNotFoundError(
                f"[TI Scheduler] Experiment file not found: {exp_path}"
            )
        spec = json.loads(exp_path.read_text(encoding="utf-8"))
        if "experiments" not in spec or not spec["experiments"]:
            raise ValueError("[TI Scheduler] 'experiments' list is missing or empty.")
        seen_ids = set()
        for i, exp in enumerate(spec["experiments"]):
            if "id" not in exp:
                raise ValueError(
                    f"[TI Scheduler] Experiment at index {i} is missing required 'id' field."
                )
            # The id names both the output sub-directory and the summary
            # record, so two experiments with one id would overwrite each other.
            declared_id = exp["id"]
            if declared_id in seen_ids:
                raise ValueError(
                    f"[TI Scheduler] Duplicate experiment id '{declared_id}' at index {i}; "
                    "ids must be unique."
                )
            seen_ids.add(declared_id)
        sweep_name = spec.get("name", exp_path.stem)
        description = spec.get("description", "")
        base_cfg = spec.get("base", {})

        # ------------------------------------------------------------------
        # 2. Resolve data_dir and output_root
        # ------------------------------------------------------------------
        if "data_dir" not in spec:
            raise ValueError("[TI Scheduler] 'data_dir' is required in the sweep file.")
        data_dir = _resolve_path(spec["data_dir"])
        output_root = _resolve_path(spec.get("output_root", f"ti_sweeps/{sweep_name}"))
        output_root.mkdir(parents=True, exist_ok=True)
        device = get_device()
        dtype = model["dtype"]
        mode = model["mode"]
        seq_cfg = model["seq_cfg"]
        feature_utils_orig = model["feature_utils"]
        print(f"\n[TI Scheduler] Sweep '{sweep_name}': "
              f"{len(spec['experiments'])} experiment(s)", flush=True)
        if description:
            print(f"[TI Scheduler] {description}", flush=True)
        print(f"[TI Scheduler] data_dir = {data_dir}", flush=True)
        print(f"[TI Scheduler] output_root = {output_root}\n", flush=True)

        # ------------------------------------------------------------------
        # 3. Load dataset once
        # ------------------------------------------------------------------
        n_clips = sum(1 for _ in data_dir.glob("*.npz"))
        dataset = _prepare_dataset(model, data_dir, device)

        # ------------------------------------------------------------------
        # 4. Build or restore summary (resume-aware)
        # ------------------------------------------------------------------
        summary_path = output_root / "experiment_summary.json"
        completed_ids = set()
        all_curve_data = []
        if summary_path.exists():
            try:
                existing = json.loads(summary_path.read_text(encoding="utf-8"))
                for rec in existing.get("experiments", []):
                    if rec.get("results", {}).get("status") == "completed":
                        completed_ids.add(rec["id"])
                        all_curve_data.append({
                            "id": rec["id"],
                            "loss_history": rec["results"].get("loss_history", []),
                        })
                summary = existing
                summary["completed_at"] = None
                if completed_ids:
                    print(f"[TI Scheduler] Resuming — skipping {len(completed_ids)} "
                          f"completed experiment(s): {sorted(completed_ids)}", flush=True)
            except Exception as e:
                # An unreadable or malformed summary must not block the sweep.
                print(f"[TI Scheduler] Could not read existing summary ({e}) — starting fresh",
                      flush=True)
                completed_ids = set()
                all_curve_data = []
                summary = None
        if not completed_ids:
            # Nothing to resume from: start with a fresh summary.
            summary = {
                "sweep_name": sweep_name,
                "description": description,
                "sweep_file": str(exp_path),
                "started_at": datetime.now(timezone.utc).isoformat(),
                "completed_at": None,
                "system": _get_system_info(),
                "data_dir": str(data_dir),
                "n_clips": n_clips,
                "experiments": [],
            }
        comparison_img_path = output_root / "loss_comparison.png"

        def _write_summary():
            # Write to a sibling temp file and swap it in, so an interruption
            # mid-write cannot leave a truncated summary that defeats resume.
            tmp_path = summary_path.with_name(summary_path.name + ".tmp")
            tmp_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
            tmp_path.replace(summary_path)

        def _save_comparison():
            # Redraw the overlay from all curves collected so far, save it and
            # return it. Best-effort: returns None when drawing fails; a failed
            # save is only reported.
            try:
                img = _draw_comparison_curves(all_curve_data)
            except Exception as e:
                print(f"[TI Scheduler] Comparison image failed: {e}", flush=True)
                return None
            try:
                img.save(str(comparison_img_path))
            except Exception as e:
                print(f"[TI Scheduler] Comparison image failed: {e}", flush=True)
            return img

        _write_summary()

        # ------------------------------------------------------------------
        # 5. Run each experiment
        # ------------------------------------------------------------------
        trainer = SelvaTextualInversionTrainer()
        pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
        log_interval = 50  # matches _train_inner
        for exp in spec["experiments"]:
            exp_id = exp["id"]
            exp_desc = exp.get("description", "")
            if exp_id in completed_ids:
                print(f"[TI Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
                pbar_outer.update(1)
                continue
            cfg = _merge_config(base_cfg, exp)
            n_tokens = int(cfg["n_tokens"])
            lr = float(cfg["lr"])
            steps = int(cfg["steps"])
            batch_size = int(cfg["batch_size"])
            warmup = int(cfg["warmup_steps"])
            seed = int(cfg["seed"])
            save_every = int(cfg["save_every"])
            init_text = str(cfg["init_text"])
            inject_mode = str(cfg["inject_mode"])
            output_dir = output_root / exp_id
            output_dir.mkdir(parents=True, exist_ok=True)
            out_path = output_dir / "embeddings.pt"
            print(f"\n[TI Scheduler] ── Experiment '{exp_id}' ──", flush=True)
            if exp_desc:
                print(f"[TI Scheduler] {exp_desc}", flush=True)
            print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} "
                  f"batch_size={batch_size} warmup={warmup} seed={seed} "
                  f"inject_mode={inject_mode}", flush=True)
            if init_text:
                print(f"[TI Scheduler] init_text='{init_text}'", flush=True)
            exp_record = {
                "id": exp_id,
                "description": exp_desc,
                "config": {
                    "n_tokens": n_tokens,
                    "lr": lr,
                    "steps": steps,
                    "batch_size": batch_size,
                    "warmup_steps": warmup,
                    "seed": seed,
                    "save_every": save_every,
                    "init_text": init_text,
                    "inject_mode": inject_mode,
                },
                "results": {"status": "running"},
                "embeddings_path": None,
                "output_dir": str(output_dir),
            }
            # A resumed summary may still hold a "failed" or interrupted
            # "running" record for this id; replace it instead of appending a
            # duplicate next to it.
            summary["experiments"] = [
                rec for rec in summary["experiments"] if rec.get("id") != exp_id
            ]
            summary["experiments"].append(exp_record)
            # Persist the "running" state before training starts.
            _write_summary()
            t_start = time.monotonic()
            try:
                # Training needs autograd even if the caller disabled it.
                with torch.inference_mode(False), torch.enable_grad():
                    r = trainer._train_inner(
                        model, dataset, feature_utils_orig, seq_cfg,
                        device, dtype, mode,
                        data_dir, out_path,
                        n_tokens, steps, lr, batch_size,
                        warmup, seed, save_every, init_text, inject_mode,
                    )
                duration = time.monotonic() - t_start
                loss_history = r["loss_history"]
                smoothed = _smooth_losses(loss_history) if loss_history else []
                final_loss = None
                min_loss = None
                min_loss_step = None
                if smoothed:
                    lowest = min(smoothed)
                    final_loss = round(smoothed[-1], 6)
                    min_loss = round(lowest, 6)
                    # Entry k of the history corresponds to step (k + 1) * log_interval.
                    min_loss_step = (smoothed.index(lowest) + 1) * log_interval
                loss_std_last_quarter = None
                if loss_history:
                    quarter = max(1, len(loss_history) // 4)
                    loss_std_last_quarter = round(float(np.std(loss_history[-quarter:])), 6)
                exp_record["results"] = {
                    "status": "completed",
                    "final_loss": final_loss,
                    "min_loss": min_loss,
                    "min_loss_step": min_loss_step,
                    "loss_std_last_quarter": loss_std_last_quarter,
                    "loss_at_steps": _loss_at_steps(
                        loss_history, log_interval, save_every, steps
                    ),
                    "loss_history": [round(v, 6) for v in loss_history],
                    "log_interval": log_interval,
                    "duration_seconds": round(duration, 1),
                }
                exp_record["embeddings_path"] = r["embeddings_path"]
                all_curve_data.append({
                    "id": exp_id,
                    "loss_history": loss_history,
                })
            except Exception as e:
                # One failing experiment must not abort the rest of the sweep.
                # NOTE(review): if ComfyUI's interrupt is raised inside training
                # as an Exception subclass, it is swallowed here too and the
                # sweep moves on to the next experiment — confirm and re-raise
                # it if cancelling the whole sweep is the desired behaviour.
                duration = time.monotonic() - t_start
                print(f"[TI Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
                traceback.print_exc()
                exp_record["results"] = {
                    "status": "failed",
                    "error": str(e),
                    "duration_seconds": round(duration, 1),
                }
                _write_summary()
                pbar_outer.update(1)
                continue
            _write_summary()
            # Keep loss_comparison.png current so an interrupted sweep still
            # leaves an image behind.
            _save_comparison()
            pbar_outer.update(1)

        # ------------------------------------------------------------------
        # 6. Finalise
        # ------------------------------------------------------------------
        summary["completed_at"] = datetime.now(timezone.utc).isoformat()
        _write_summary()
        print(f"\n[TI Scheduler] Sweep complete. Summary: {summary_path}", flush=True)

        # ------------------------------------------------------------------
        # 7. Comparison image (final update, then return to ComfyUI)
        # ------------------------------------------------------------------
        comparison_img = _save_comparison()
        if comparison_img is None:
            # Drawing the curves failed; return the empty placeholder plot
            # rather than failing a sweep that has already finished.
            comparison_img = _draw_comparison_curves([])
        return (str(summary_path), _pil_to_tensor(comparison_img))