From e37bfe1b1c128c02d313e25c42411d3552347788 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Wed, 8 Apr 2026 23:13:04 +0200 Subject: [PATCH] feat: add SelVA TI Scheduler for sweep-based textual inversion experiments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - SelvaTiScheduler: runs a JSON-defined sweep of TI training experiments, loading the dataset once and reusing it across runs - Collects per-experiment loss history, final/min loss, stability metric (loss_std_last_quarter), and duration — written to experiment_summary.json after each completed run so partial sweeps survive interruption - Resume-aware: skips experiments already marked completed in an existing summary file - Outputs smoothed loss comparison chart (same axes, one curve per experiment) - SelvaTextualInversionTrainer._train_inner now returns a dict {embeddings_path, loss_history} so the scheduler can read results; train() extracts just the path for ComfyUI JSON format: name, description, data_dir, output_root, base config, experiments list with id + param overrides Co-Authored-By: Claude Sonnet 4.6 --- nodes/__init__.py | 1 + nodes/selva_textual_inversion_trainer.py | 8 +- nodes/selva_ti_scheduler.py | 467 +++++++++++++++++++++++ 3 files changed, 474 insertions(+), 2 deletions(-) create mode 100644 nodes/selva_ti_scheduler.py diff --git a/nodes/__init__.py b/nodes/__init__.py index c696387..4fe05fd 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -16,6 +16,7 @@ _NODES = { "SelvaSpectralMatcher": (".selva_audio_preprocessors", "SelvaSpectralMatcher", "SelVA Spectral Matcher"), "SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"), "SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"), + "SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"), } for key, 
(module_path, class_name, display_name) in _NODES.items(): diff --git a/nodes/selva_textual_inversion_trainer.py b/nodes/selva_textual_inversion_trainer.py index be7dc02..41a8fd0 100644 --- a/nodes/selva_textual_inversion_trainer.py +++ b/nodes/selva_textual_inversion_trainer.py @@ -201,13 +201,14 @@ class SelvaTextualInversionTrainer: # Training must run outside inference_mode so autograd works with torch.inference_mode(False), torch.enable_grad(): - return self._train_inner( + r = self._train_inner( model, dataset, feature_utils_orig, seq_cfg, device, dtype, mode, data_dir, out_path, n_tokens, steps, lr, batch_size, warmup_steps, seed, save_every, init_text, ) + return (r["embeddings_path"],) def _train_inner( self, model, dataset, feature_utils_orig, seq_cfg, @@ -368,4 +369,7 @@ class SelvaTextualInversionTrainer: print(f"\n[TI Trainer] Done. Saved: {out_path}", flush=True) soft_empty_cache() - return (str(out_path),) + return { + "embeddings_path": str(out_path), + "loss_history": loss_history, + } diff --git a/nodes/selva_ti_scheduler.py b/nodes/selva_ti_scheduler.py new file mode 100644 index 0000000..75a1b4d --- /dev/null +++ b/nodes/selva_ti_scheduler.py @@ -0,0 +1,467 @@ +"""SelVA Textual Inversion Scheduler — sweeps TI training experiments from a JSON file. + +Each experiment inherits from a shared `base` config and overrides specific keys. +The dataset is loaded once and reused across all experiments. Results are written +to `experiment_summary.json` (updated after each completed run) and a comparison +loss-curve image showing all runs on the same axes. 
# ---------------------------------------------------------------------------
# Helpers (shared with LoRA scheduler, inlined to keep modules independent)
# ---------------------------------------------------------------------------

def _get_system_info() -> dict:
    """Describe the local PyTorch / CUDA setup for the sweep summary.

    Returns:
        A dict with the keys ``torch_version``, ``cuda_version`` (``"N/A"``
        when ``torch.version.cuda`` is empty), ``gpu_name`` and
        ``gpu_vram_gb``. The two GPU entries stay ``None`` when CUDA is
        unavailable or when querying device 0 raises.
    """
    gpu_name = None
    gpu_vram_gb = None

    if torch.cuda.is_available():
        try:
            gpu_name = torch.cuda.get_device_name(0)
            total_bytes = torch.cuda.get_device_properties(0).total_memory
            gpu_vram_gb = round(total_bytes / 1e9, 1)
        except Exception:
            # Best effort only: this is metadata for the summary file, so a
            # failed query must never abort the sweep. Whatever was read
            # before the failure is kept.
            pass

    cuda_version = torch.version.cuda
    return {
        "torch_version": torch.__version__,
        "cuda_version": cuda_version if cuda_version else "N/A",
        "gpu_name": gpu_name,
        "gpu_vram_gb": gpu_vram_gb,
    }


# Fallback value for every training parameter; a sweep's `base` section and
# the individual experiments override these key by key.
_PARAM_DEFAULTS = dict(
    n_tokens=4,
    lr=1e-3,
    steps=3000,
    batch_size=16,
    warmup_steps=100,
    seed=42,
    save_every=1000,
    init_text="",
)

# Curve colours as (R, G, B) triples; experiments cycle through them in order.
_PALETTE = [
    (66, 133, 244),
    (234, 67, 53),
    (52, 168, 83),
    (251, 188, 5),
    (155, 89, 182),
    (26, 188, 156),
    (230, 126, 34),
    (149, 165, 166),
]
def _resolve_path(raw: str) -> Path:
    """Turn a path string from the sweep file into a concrete ``Path``.

    Surrounding whitespace is stripped. A relative path is placed under the
    ComfyUI output directory. On Windows, a path that reports itself as
    absolute but carries no drive letter is treated the same way, with its
    leading root removed.

    Args:
        raw: Path text as written in the sweep JSON.

    Returns:
        The path unchanged when it is absolute, otherwise the path joined
        onto ``folder_paths.get_output_directory()``.
    """
    p = Path(raw.strip())
    unix_style_on_windows = (
        sys.platform == "win32" and p.is_absolute() and not p.drive
    )
    if not p.is_absolute() or unix_style_on_windows:
        # `p.anchor` is "" for a relative path, so relative_to() is a no-op
        # there and only strips the root of a rooted path.
        p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
    return p


def _merge_config(base: dict, experiment: dict) -> dict:
    """Layer defaults, the sweep's base config and one experiment's overrides.

    Later layers win: ``_PARAM_DEFAULTS`` < ``base`` < ``experiment``. The
    experiment's ``id`` and ``description`` keys are metadata and are not
    copied into the result.

    Args:
        base: Shared settings from the sweep file. ``None`` (a JSON ``null``)
            is accepted and treated as empty.
        experiment: One entry of the sweep's ``experiments`` list.

    Returns:
        A new dict; neither input is modified.
    """
    cfg = dict(_PARAM_DEFAULTS)
    cfg.update(base or {})
    cfg.update(
        {k: v for k, v in experiment.items() if k not in ("id", "description")}
    )
    return cfg


def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
                   total_steps: int) -> dict:
    """Pick the logged loss at every checkpoint step.

    Entry ``i`` of ``loss_history`` is taken to be the value logged at step
    ``(i + 1) * log_interval``.

    Args:
        loss_history: Logged loss values, oldest first.
        log_interval: Number of training steps between two logged values.
        save_every: Number of training steps between two checkpoints.
        total_steps: Last training step to consider (inclusive).

    Returns:
        A dict mapping the checkpoint step (as a string) to the loss rounded
        to six decimals. Checkpoints without a logged value are left out. The
        dict is empty when ``save_every`` or ``log_interval`` is not positive,
        so that a "0 = disabled" setting cannot raise after a finished run.
    """
    if save_every <= 0 or log_interval <= 0:
        return {}

    result = {}
    for target in range(save_every, total_steps + 1, save_every):
        idx = target // log_interval - 1
        # idx is -1 when the checkpoint precedes the first logged value.
        if 0 <= idx < len(loss_history):
            result[str(target)] = round(loss_history[idx], 6)
    return result


def _draw_comparison_curves(experiments_data: list) -> "Image.Image":
    """Plot the smoothed loss curve of every experiment on shared axes.

    Args:
        experiments_data: Dicts with an ``id`` and a ``loss_history`` list.
            Entries with fewer than two loss values are skipped.

    Returns:
        A 900x420 RGB PIL image with a grid, one polyline per experiment and
        a legend on the right. When nothing can be plotted the image only
        carries the text "No data to plot".

    Non-finite smoothed values (NaN / inf from a diverged run) are dropped
    before scaling: they cannot be converted to pixel coordinates and would
    otherwise abort the drawing. The remaining points keep their horizontal
    position, so the line simply bridges the gap.
    """
    import math

    from PIL import Image, ImageDraw

    W, H = 900, 420
    pl, pr, pt, pb = 75, 160, 30, 50  # left / right / top / bottom padding

    img = Image.new("RGB", (W, H), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    pw = W - pl - pr
    ph = H - pt - pb

    series = []
    for i, ed in enumerate(experiments_data):
        lh = ed.get("loss_history") or []
        if len(lh) < 2:
            continue
        sm = _smooth_losses(lh)
        finite = [(j, v) for j, v in enumerate(sm) if math.isfinite(v)]
        if len(finite) < 2:
            continue
        series.append({
            # str() so a numeric id from the JSON file can be sliced below.
            "id": str(ed["id"]),
            "length": len(sm),
            "points": finite,
            # Indexed by position in the input so colours stay stable even
            # when some experiments are skipped.
            "color": _PALETTE[i % len(_PALETTE)],
        })

    if not series:
        draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
        return img

    all_vals = [v for s in series for _, v in s["points"]]
    lo, hi = min(all_vals), max(all_vals)
    if hi == lo:
        hi = lo + 1e-6  # avoid a zero-height value range
    rng = hi - lo

    # Five horizontal grid lines with their loss value on the left.
    for i in range(5):
        y = pt + int(i * ph / 4)
        val = hi - i * rng / 4
        draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
        draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))

    for s in series:
        n = s["length"]
        pts = [
            (
                pl + int(j * pw / max(n - 1, 1)),
                pt + int((1.0 - (v - lo) / rng) * ph),
            )
            for j, v in s["points"]
        ]
        draw.line(pts, fill=s["color"], width=2)

    # Axes and title.
    draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
    draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
    draw.text((pl + 4, 8), "TI loss comparison (smoothed)", fill=(40, 40, 40))

    # Legend: one colour swatch and (truncated) id per plotted experiment.
    lx, ly = W - pr + 10, pt
    for s in series:
        draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
        draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
        ly += 20

    return img
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------

class SelvaTiScheduler:
    """Runs a sweep of Textual Inversion experiments defined in a JSON file.

    The dataset is loaded once and reused. Each experiment calls
    SelvaTextualInversionTrainer._train_inner() with its own config.
    Results are written to experiment_summary.json after every completed run,
    and a later run on the same output_root skips experiments that are
    already recorded as completed.
    """

    OUTPUT_NODE = True
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "run"
    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("summary_path", "comparison_curves")
    OUTPUT_TOOLTIPS = (
        "Path to experiment_summary.json — compare runs across sweeps.",
        "All smoothed loss curves overlaid on the same axes.",
    )
    DESCRIPTION = (
        "Runs a series of Textual Inversion experiments from a JSON sweep file. "
        "The dataset is encoded once and reused. Results (loss, config, embeddings "
        "paths) are collected in experiment_summary.json after each run."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the loaded model and the sweep file path."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "experiments_file": ("STRING", {
                    "default": "ti_experiments.json",
                    "tooltip": (
                        "Path to JSON sweep file. Relative paths are looked up in "
                        "the ComfyUI models directory first, then in the output "
                        "directory. See node description for the file format."
                    ),
                }),
            }
        }

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _locate_spec_file(experiments_file):
        """Resolve the sweep file path; raise FileNotFoundError if absent."""
        exp_path = Path(experiments_file.strip())
        if not exp_path.is_absolute():
            candidate = Path(folder_paths.models_dir) / exp_path
            if not candidate.exists():
                candidate = Path(folder_paths.get_output_directory()) / exp_path
            exp_path = candidate
        if not exp_path.exists():
            raise FileNotFoundError(
                f"[TI Scheduler] Experiment file not found: {exp_path}"
            )
        return exp_path

    @staticmethod
    def _load_spec(exp_path):
        """Parse the sweep file and validate its structure.

        Raises:
            ValueError: the file is not a JSON object, `experiments` is
                missing/empty/not a list, an experiment lacks `id`, two
                experiments share an id, or `data_dir` is missing.
        """
        spec = json.loads(exp_path.read_text(encoding="utf-8"))
        if not isinstance(spec, dict):
            raise ValueError("[TI Scheduler] Sweep file must contain a JSON object.")

        experiments = spec.get("experiments")
        if not isinstance(experiments, list) or not experiments:
            raise ValueError("[TI Scheduler] 'experiments' list is missing or empty.")

        seen_ids = set()
        for i, exp in enumerate(experiments):
            if not isinstance(exp, dict) or "id" not in exp:
                raise ValueError(
                    f"[TI Scheduler] Experiment at index {i} is missing required 'id' field."
                )
            exp_id = str(exp["id"])
            # Each id names an output directory and a summary record; a
            # duplicate would silently overwrite the earlier run's results.
            if exp_id in seen_ids:
                raise ValueError(
                    f"[TI Scheduler] Duplicate experiment id '{exp_id}' at index {i}."
                )
            seen_ids.add(exp_id)

        if "data_dir" not in spec:
            raise ValueError("[TI Scheduler] 'data_dir' is required in the sweep file.")
        return spec

    @staticmethod
    def _typed_config(cfg):
        """Coerce a merged config to the types passed to the trainer.

        Raises TypeError / ValueError when a value cannot be converted.
        """
        return {
            "n_tokens": int(cfg["n_tokens"]),
            "lr": float(cfg["lr"]),
            "steps": int(cfg["steps"]),
            "batch_size": int(cfg["batch_size"]),
            "warmup_steps": int(cfg["warmup_steps"]),
            "seed": int(cfg["seed"]),
            "save_every": int(cfg["save_every"]),
            "init_text": str(cfg["init_text"]),
        }

    @staticmethod
    def _restore_summary(summary_path):
        """Load a previous summary so completed experiments can be skipped.

        Returns:
            `(summary, completed_ids, curve_data)`. `summary` is None (and the
            other two empty) when there is no usable summary or it contains
            no completed experiment; the caller then starts a fresh one.

        Only completed records are kept. Records left as "running", "failed"
        or "interrupted" belong to experiments that will be run again and get
        a new record, so keeping them would duplicate the id in the summary.
        """
        if not summary_path.exists():
            return None, set(), []

        try:
            existing = json.loads(summary_path.read_text(encoding="utf-8"))
            kept = []
            curves = []
            for rec in existing.get("experiments", []):
                if rec.get("results", {}).get("status") == "completed":
                    kept.append(rec)
                    curves.append({
                        "id": rec["id"],
                        "loss_history": rec["results"].get("loss_history", []),
                    })
        except (OSError, ValueError, KeyError, AttributeError, TypeError) as e:
            # An unreadable or malformed summary must not block the sweep.
            print(f"[TI Scheduler] Could not read existing summary ({e}) — starting fresh",
                  flush=True)
            return None, set(), []

        if not kept:
            return None, set(), []

        existing["experiments"] = kept
        existing["completed_at"] = None
        return existing, {str(rec["id"]) for rec in kept}, curves

    @staticmethod
    def _write_summary(summary_path, summary):
        """Write the summary via a temp file and rename.

        Writing in place would leave a truncated file if the process is
        killed mid-write, and a truncated file makes the next resume discard
        every recorded result.
        """
        tmp_path = summary_path.with_name(summary_path.name + ".tmp")
        tmp_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
        tmp_path.replace(summary_path)

    @staticmethod
    def _summarise_run(loss_history, log_interval, save_every, steps, duration):
        """Build the `results` record of a successfully trained experiment."""
        smoothed = _smooth_losses(loss_history) if loss_history else []

        final_loss = min_loss = min_loss_step = None
        if smoothed:
            best = min(smoothed)
            final_loss = round(smoothed[-1], 6)
            min_loss = round(best, 6)
            min_loss_step = (smoothed.index(best) + 1) * log_interval

        # Stability metric: spread of the raw loss over the last quarter.
        loss_std_last_quarter = None
        if loss_history:
            quarter = max(1, len(loss_history) // 4)
            loss_std_last_quarter = round(float(np.std(loss_history[-quarter:])), 6)

        return {
            "status": "completed",
            "final_loss": final_loss,
            "min_loss": min_loss,
            "min_loss_step": min_loss_step,
            "loss_std_last_quarter": loss_std_last_quarter,
            "loss_at_steps": _loss_at_steps(
                loss_history, log_interval, save_every, steps
            ),
            "loss_history": [round(v, 6) for v in loss_history],
            "log_interval": log_interval,
            "duration_seconds": round(duration, 1),
        }

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    def run(self, model, experiments_file):
        """Execute every pending experiment of the sweep.

        Args:
            model: SELVA_MODEL dict; `dtype`, `mode`, `seq_cfg` and
                `feature_utils` are read from it.
            experiments_file: Path to the JSON sweep file.

        Returns:
            `(summary_path, comparison_image_tensor)`.

        Raises:
            FileNotFoundError: the sweep file does not exist.
            ValueError: the sweep file is malformed or a pending experiment
                has a config value of the wrong type.

        A failing experiment is recorded as "failed" and the sweep continues.
        A user cancel is recorded as "interrupted" and re-raised so the whole
        sweep stops.
        """
        # Imported here because it is only needed to tell a user cancel apart
        # from an ordinary training failure.
        from comfy.model_management import InterruptProcessingException

        # 1. Read + validate JSON
        exp_path = self._locate_spec_file(experiments_file)
        spec = self._load_spec(exp_path)
        experiments = spec["experiments"]

        sweep_name = spec.get("name", exp_path.stem)
        description = spec.get("description", "")
        base_cfg = spec.get("base") or {}  # tolerate `"base": null`

        # 2. Resolve data_dir and output_root
        data_dir = _resolve_path(spec["data_dir"])
        output_root = _resolve_path(spec.get("output_root", f"ti_sweeps/{sweep_name}"))
        output_root.mkdir(parents=True, exist_ok=True)

        device = get_device()
        dtype = model["dtype"]
        mode = model["mode"]
        seq_cfg = model["seq_cfg"]
        feature_utils_orig = model["feature_utils"]

        print(f"\n[TI Scheduler] Sweep '{sweep_name}': "
              f"{len(experiments)} experiment(s)", flush=True)
        if description:
            print(f"[TI Scheduler] {description}", flush=True)
        print(f"[TI Scheduler] data_dir    = {data_dir}", flush=True)
        print(f"[TI Scheduler] output_root = {output_root}\n", flush=True)

        # 3. Restore an earlier summary (resume-aware)
        summary_path = output_root / "experiment_summary.json"
        summary, completed_ids, all_curve_data = self._restore_summary(summary_path)
        if completed_ids:
            print(f"[TI Scheduler] Resuming — skipping {len(completed_ids)} "
                  f"completed experiment(s): {sorted(completed_ids)}", flush=True)

        # 4. Validate every pending config before any training starts, so a
        #    typo in a late experiment is reported immediately.
        pending = {}
        for exp in experiments:
            exp_id = str(exp["id"])
            if exp_id in completed_ids:
                continue
            try:
                pending[exp_id] = self._typed_config(_merge_config(base_cfg, exp))
            except (TypeError, ValueError) as err:
                raise ValueError(
                    f"[TI Scheduler] Invalid config for experiment '{exp_id}': {err}"
                ) from err

        # 5. Load dataset once — skipped when nothing is left to train
        dataset = _prepare_dataset(model, data_dir, device) if pending else None

        if summary is None:
            summary = {
                "sweep_name": sweep_name,
                "description": description,
                "sweep_file": str(exp_path),
                "started_at": datetime.now(timezone.utc).isoformat(),
                "completed_at": None,
                "system": _get_system_info(),
                "data_dir": str(data_dir),
                "n_clips": sum(1 for _ in data_dir.glob("*.npz")),
                "experiments": [],
            }

        self._write_summary(summary_path, summary)

        # 6. Run each experiment
        trainer = SelvaTextualInversionTrainer()
        pbar_outer = comfy.utils.ProgressBar(len(experiments))
        log_interval = 50  # matches _train_inner

        for exp in experiments:
            exp_id = str(exp["id"])
            exp_desc = exp.get("description", "")

            if exp_id in completed_ids:
                print(f"[TI Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
                pbar_outer.update(1)
                continue

            cfg = pending[exp_id]
            output_dir = output_root / exp_id
            output_dir.mkdir(parents=True, exist_ok=True)
            out_path = output_dir / "embeddings.pt"

            print(f"\n[TI Scheduler] ── Experiment '{exp_id}' ──", flush=True)
            if exp_desc:
                print(f"[TI Scheduler] {exp_desc}", flush=True)
            print(f"[TI Scheduler] n_tokens={cfg['n_tokens']}  lr={cfg['lr']:.2e}  "
                  f"steps={cfg['steps']}  batch_size={cfg['batch_size']}  "
                  f"warmup={cfg['warmup_steps']}  seed={cfg['seed']}", flush=True)
            if cfg["init_text"]:
                print(f"[TI Scheduler] init_text='{cfg['init_text']}'", flush=True)

            # Recorded as "running" first so a crash leaves a visible trace.
            exp_record = {
                "id": exp_id,
                "description": exp_desc,
                "config": dict(cfg),
                "results": {"status": "running"},
                "embeddings_path": None,
                "output_dir": str(output_dir),
            }
            summary["experiments"].append(exp_record)
            self._write_summary(summary_path, summary)

            t_start = time.monotonic()
            try:
                # Training must run outside inference_mode so autograd works.
                with torch.inference_mode(False), torch.enable_grad():
                    r = trainer._train_inner(
                        model, dataset, feature_utils_orig, seq_cfg,
                        device, dtype, mode,
                        data_dir, out_path,
                        cfg["n_tokens"], cfg["steps"], cfg["lr"], cfg["batch_size"],
                        cfg["warmup_steps"], cfg["seed"], cfg["save_every"],
                        cfg["init_text"],
                    )

                duration = time.monotonic() - t_start
                loss_history = r["loss_history"]
                exp_record["results"] = self._summarise_run(
                    loss_history, log_interval, cfg["save_every"], cfg["steps"], duration
                )
                exp_record["embeddings_path"] = r["embeddings_path"]
                all_curve_data.append({
                    "id": exp_id,
                    "loss_history": loss_history,
                })

            except InterruptProcessingException:
                # The user cancelled: stop the whole sweep rather than
                # moving on to the next experiment.
                exp_record["results"] = {
                    "status": "interrupted",
                    "duration_seconds": round(time.monotonic() - t_start, 1),
                }
                self._write_summary(summary_path, summary)
                raise

            except Exception as e:
                # One broken experiment must not cost the rest of the sweep.
                duration = time.monotonic() - t_start
                print(f"[TI Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
                traceback.print_exc()
                exp_record["results"] = {
                    "status": "failed",
                    "error": str(e),
                    "duration_seconds": round(duration, 1),
                }

            self._write_summary(summary_path, summary)
            pbar_outer.update(1)

        # 7. Finalise
        summary["completed_at"] = datetime.now(timezone.utc).isoformat()
        self._write_summary(summary_path, summary)
        print(f"\n[TI Scheduler] Sweep complete. Summary: {summary_path}", flush=True)

        # 8. Comparison image
        comparison_img = _draw_comparison_curves(all_curve_data)
        comparison_img.save(str(output_root / "loss_comparison.png"))
        comparison_tensor = _pil_to_tensor(comparison_img)

        return (str(summary_path), comparison_tensor)