feat: add SelVA LoRA Scheduler node for automated experiment sweeps
- Extract _prepare_dataset() from SelvaLoraTrainer.train() as a module-level function so the dataset can be encoded once and reused across experiments - Change _train_inner() return value from tuple to dict (adds loss_history, meta, completed; train() unpacks for ComfyUI — no change to node outputs) - New SelvaLoraScheduler node: reads a JSON sweep file, runs N experiments sequentially, writes experiment_summary.json (updated after each run) and loss_comparison.png with all smoothed curves overlaid on the same axes - Register SelvaLoraScheduler in nodes/__init__.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,7 @@ _NODES = {
|
|||||||
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
|
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
|
||||||
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
|
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
|
||||||
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
|
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
|
||||||
|
"SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||||
|
|||||||
@@ -0,0 +1,436 @@
|
|||||||
|
"""SelVA LoRA Scheduler — runs a sweep of training experiments from a JSON file.
|
||||||
|
|
||||||
|
Each experiment inherits from a shared `base` config and overrides specific keys.
|
||||||
|
The dataset is loaded once and reused across all experiments. Results are written
|
||||||
|
to `experiment_summary.json` (updated after each completed run) and a comparison
|
||||||
|
loss-curve image showing all runs on the same axes.
|
||||||
|
|
||||||
|
JSON format:
|
||||||
|
{
|
||||||
|
"name": "tier1_sweep",
|
||||||
|
"description": "optional human note",
|
||||||
|
"data_dir": "dataset/dog_bark",
|
||||||
|
"output_root": "lora_output/tier1_sweep",
|
||||||
|
"base": { "rank": 16, "lr": 1e-4, "steps": 2000, ... },
|
||||||
|
"experiments": [
|
||||||
|
{"id": "baseline", "description": "..."},
|
||||||
|
{"id": "lora_plus_16", "lora_plus_ratio": 16.0},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
from .utils import SELVA_CATEGORY, get_device
|
||||||
|
from .selva_lora_trainer import (
|
||||||
|
SelvaLoraTrainer,
|
||||||
|
_prepare_dataset,
|
||||||
|
_smooth_losses,
|
||||||
|
_pil_to_tensor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Defaults mirror SelvaLoraTrainer INPUT_TYPES defaults
|
||||||
|
_PARAM_DEFAULTS = {
|
||||||
|
"alpha": 0.0,
|
||||||
|
"target": "attn.qkv",
|
||||||
|
"batch_size": 4,
|
||||||
|
"warmup_steps": 100,
|
||||||
|
"grad_accum": 1,
|
||||||
|
"save_every": 500,
|
||||||
|
"resume_path": "",
|
||||||
|
"seed": 42,
|
||||||
|
"timestep_mode": "uniform",
|
||||||
|
"logit_normal_sigma": 1.0,
|
||||||
|
"curriculum_switch": 0.6,
|
||||||
|
"lora_dropout": 0.0,
|
||||||
|
"lora_plus_ratio": 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Palette for comparison chart: one color per experiment (cycles if > 8)
|
||||||
|
_PALETTE = [
|
||||||
|
(66, 133, 244), # blue
|
||||||
|
(234, 67, 53), # red
|
||||||
|
(52, 168, 83), # green
|
||||||
|
(251, 188, 5), # yellow
|
||||||
|
(155, 89, 182), # purple
|
||||||
|
(26, 188, 156), # teal
|
||||||
|
(230, 126, 34), # orange
|
||||||
|
(149, 165, 166), # grey
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path(raw: str) -> Path:
|
||||||
|
"""Resolve path the same way SelvaLoraTrainer does (relative → ComfyUI output dir)."""
|
||||||
|
p = Path(raw.strip())
|
||||||
|
unix_style_on_windows = (
|
||||||
|
sys.platform == "win32" and p.is_absolute() and not p.drive
|
||||||
|
)
|
||||||
|
if not p.is_absolute() or unix_style_on_windows:
|
||||||
|
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_config(base: dict, experiment: dict) -> dict:
    """Layer built-in defaults, the sweep file's base section and one experiment.

    Later layers win: ``_PARAM_DEFAULTS`` < ``base`` < ``experiment``. The
    experiment's ``id`` and ``description`` keys are bookkeeping only and are
    left out of the returned training parameters. Neither input is modified.
    """
    merged = _PARAM_DEFAULTS.copy()
    merged.update(base)
    for key, value in experiment.items():
        if key in ("id", "description"):
            continue
        merged[key] = value
    return merged
|
||||||
|
|
||||||
|
|
||||||
|
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
|
||||||
|
start_step: int, total_steps: int) -> dict:
|
||||||
|
"""Build a dict of {step: loss} at each save_every boundary.
|
||||||
|
|
||||||
|
loss_history[i] = average loss over steps [start + i*log_interval + 1 …
|
||||||
|
start + (i+1)*log_interval].
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
targets = range(save_every, total_steps + 1, save_every)
|
||||||
|
for target in targets:
|
||||||
|
# index of the loss entry nearest to this step
|
||||||
|
idx = (target - start_step) // log_interval - 1
|
||||||
|
if 0 <= idx < len(loss_history):
|
||||||
|
result[str(target)] = round(loss_history[idx], 6)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _draw_comparison_curves(
    experiments_data: list,  # list of dicts: {id, loss_history, log_interval, start_step}
) -> Image.Image:
    """Draw all smoothed loss curves on the same axes, one color per experiment.

    Args:
        experiments_data: one dict per experiment with keys ``id``,
            ``loss_history`` and optionally ``log_interval`` (default 50) and
            ``start_step`` (default 0). Experiments with fewer than two loss
            entries are left out of the plot.

    Returns:
        A 900x420 RGB image. When nothing can be plotted the image only
        contains a "No data to plot" notice.

    Both axes are shared by every curve: the y-range spans all smoothed
    values, and the x-position of a point is derived from its training step
    (``start_step + (j + 1) * log_interval``), so runs of different length end
    at different x-positions instead of each being stretched to full width.
    """
    W, H = 900, 420
    pl, pr, pt, pb = 75, 160, 30, 50  # wider right margin for legend

    img = Image.new("RGB", (W, H), (255, 255, 255))
    draw = ImageDraw.Draw(img)

    pw = W - pl - pr
    ph = H - pt - pb

    # Collect all smoothed series. The colour is tied to the experiment's
    # position in the input, so skipped experiments do not shift the palette.
    series = []
    for i, ed in enumerate(experiments_data):
        lh = ed.get("loss_history") or []
        if len(lh) < 2:
            continue
        series.append({
            "id": ed["id"],
            "smoothed": _smooth_losses(lh),
            "log_interval": ed.get("log_interval", 50),
            "start_step": ed.get("start_step", 0),
            "color": _PALETTE[i % len(_PALETTE)],
        })

    if not series:
        draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
        return img

    all_vals = [v for s in series for v in s["smoothed"]]
    lo, hi = min(all_vals), max(all_vals)
    if hi == lo:
        hi = lo + 1e-6  # avoid a zero-height y-range
    rng = hi - lo

    # Shared x-range in training steps.
    # NOTE(review): assumes _smooth_losses returns one value per input entry,
    # so smoothed[j] belongs to the j-th logging window — confirm.
    first_step = min(s["start_step"] + s["log_interval"] for s in series)
    last_step = max(
        s["start_step"] + len(s["smoothed"]) * s["log_interval"] for s in series
    )
    step_span = max(last_step - first_step, 1)

    # Horizontal grid + y-axis labels
    for i in range(5):
        y = pt + int(i * ph / 4)
        val = hi - i * rng / 4
        draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
        draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))

    # Draw each curve
    for s in series:
        pts = []
        for j, v in enumerate(s["smoothed"]):
            step = s["start_step"] + (j + 1) * s["log_interval"]
            x = pl + int((step - first_step) * pw / step_span)
            y = pt + int((1.0 - (v - lo) / rng) * ph)
            pts.append((x, y))
        draw.line(pts, fill=s["color"], width=2)

    # Axes
    draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
    draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
    draw.text((pl + 4, 8), "Loss comparison (smoothed)", fill=(40, 40, 40))

    # Legend (right side)
    lx = W - pr + 10
    ly = pt
    for s in series:
        draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
        # str() so a non-string id cannot break the slice.
        draw.text((lx + 18, ly), str(s["id"])[:20], fill=(40, 40, 40))
        ly += 20

    return img
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaLoraScheduler:
    """Runs a sweep of LoRA training experiments defined in a JSON file.

    The dataset is prepared once with ``_prepare_dataset`` and the same object
    is handed to every experiment; the training itself is delegated to
    ``SelvaLoraTrainer._train_inner``. ``experiment_summary.json`` is rewritten
    before and after every experiment so partial results are preserved if the
    sweep is interrupted.

    A single experiment that fails — invalid parameter values, an unusable
    output directory, or an error raised during training — is recorded with
    status "failed" and the sweep continues with the next one.
    """

    OUTPUT_NODE = True
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "run"
    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("summary_path", "comparison_curves")
    OUTPUT_TOOLTIPS = (
        "Path to experiment_summary.json — share this file to compare runs.",
        "All smoothed loss curves overlaid on the same axes.",
    )
    DESCRIPTION = (
        "Runs a series of LoRA training experiments defined in a JSON sweep file. "
        "The dataset is encoded once and reused across all experiments. "
        "Results (loss, config, adapter paths) are collected in experiment_summary.json."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the loaded model and the sweep file path."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "experiments_file": ("STRING", {
                    "default": "experiments.json",
                    "tooltip": (
                        "Path to JSON sweep file. Relative paths are looked up in the "
                        "ComfyUI models directory first, then in the output directory; "
                        "absolute paths are used as-is. "
                        "See LORA_TRAINING.md for the file format."
                    ),
                }),
            }
        }

    @staticmethod
    def _coerce_params(cfg: dict) -> dict:
        """Convert a merged experiment config into typed training parameters.

        Raises ``ValueError`` / ``TypeError`` when a value cannot be converted.
        ``alpha`` falls back to ``rank`` when it is not positive, and ``target``
        is split on whitespace into a tuple of suffixes.
        """
        rank = int(cfg.get("rank", 16))
        alpha = float(cfg.get("alpha", 0.0))
        target = str(cfg.get("target", "attn.qkv"))
        return {
            "steps": int(cfg.get("steps", 2000)),
            "rank": rank,
            "lr": float(cfg.get("lr", 1e-4)),
            "alpha": alpha if alpha > 0.0 else float(rank),
            "target_suffixes": tuple(target.strip().split()),
            "batch_size": int(cfg.get("batch_size", 4)),
            "warmup_steps": int(cfg.get("warmup_steps", 100)),
            "grad_accum": int(cfg.get("grad_accum", 1)),
            "save_every": int(cfg.get("save_every", 500)),
            "resume_path": str(cfg.get("resume_path", "")),
            "seed": int(cfg.get("seed", 42)),
            "timestep_mode": str(cfg.get("timestep_mode", "uniform")),
            "logit_normal_sigma": float(cfg.get("logit_normal_sigma", 1.0)),
            "curriculum_switch": float(cfg.get("curriculum_switch", 0.6)),
            "lora_dropout": float(cfg.get("lora_dropout", 0.0)),
            "lora_plus_ratio": float(cfg.get("lora_plus_ratio", 1.0)),
        }

    def run(self, model, experiments_file):
        """Execute every experiment in the sweep file and collect the results.

        Args:
            model: the SELVA_MODEL dict; this method reads its ``dtype``,
                ``feature_utils``, ``seq_cfg``, ``variant`` and ``mode`` entries.
            experiments_file: path to the JSON sweep description.

        Returns:
            ``(summary_path, comparison_curves)`` — the path of
            ``experiment_summary.json`` as a string and the comparison chart
            converted with ``_pil_to_tensor``.

        Raises:
            FileNotFoundError: the sweep file does not exist.
            ValueError: the sweep file has no experiments, an experiment has
                no ``id``, or ``data_dir`` is missing.
        """
        # ------------------------------------------------------------------
        # 1. Read + validate the JSON file
        # ------------------------------------------------------------------
        exp_path = Path(experiments_file.strip())
        if not exp_path.is_absolute():
            # Try relative to ComfyUI models dir first, then output dir
            candidate = Path(folder_paths.models_dir) / exp_path
            if not candidate.exists():
                candidate = Path(folder_paths.get_output_directory()) / exp_path
            exp_path = candidate
        if not exp_path.exists():
            raise FileNotFoundError(
                f"[LoRA Scheduler] Experiment file not found: {exp_path}"
            )

        spec = json.loads(exp_path.read_text(encoding="utf-8"))

        if "experiments" not in spec or not spec["experiments"]:
            raise ValueError("[LoRA Scheduler] 'experiments' list is missing or empty.")
        for i, exp in enumerate(spec["experiments"]):
            # Non-dict entries (e.g. a bare number) are reported the same way
            # as a dict without an id instead of raising a TypeError.
            if not isinstance(exp, dict) or "id" not in exp:
                raise ValueError(
                    f"[LoRA Scheduler] Experiment at index {i} is missing required 'id' field."
                )

        sweep_name = spec.get("name", exp_path.stem)
        description = spec.get("description", "")
        # `or {}` also covers an explicit `"base": null` in the file.
        base_cfg = spec.get("base") or {}

        # ------------------------------------------------------------------
        # 2. Resolve data_dir and output_root
        # ------------------------------------------------------------------
        if "data_dir" not in spec:
            raise ValueError("[LoRA Scheduler] 'data_dir' is required in the sweep file.")
        data_dir = _resolve_path(spec["data_dir"])
        output_root = _resolve_path(spec.get("output_root", f"lora_sweeps/{sweep_name}"))
        output_root.mkdir(parents=True, exist_ok=True)

        device = get_device()
        dtype = model["dtype"]

        print(f"\n[LoRA Scheduler] Sweep '{sweep_name}': "
              f"{len(spec['experiments'])} experiment(s)", flush=True)
        if description:
            print(f"[LoRA Scheduler] {description}", flush=True)
        print(f"[LoRA Scheduler] data_dir = {data_dir}", flush=True)
        print(f"[LoRA Scheduler] output_root = {output_root}\n", flush=True)

        # ------------------------------------------------------------------
        # 3. Load + encode dataset once
        # ------------------------------------------------------------------
        # n_clips counts the .npz files found; _prepare_dataset may skip some.
        n_clips = len(list(data_dir.glob("*.npz")))
        dataset = _prepare_dataset(model, data_dir, device)

        # ------------------------------------------------------------------
        # 4. Build the summary skeleton (written incrementally)
        # ------------------------------------------------------------------
        summary = {
            "sweep_name": sweep_name,
            "description": description,
            "sweep_file": str(exp_path),
            "started_at": datetime.now(timezone.utc).isoformat(),
            "completed_at": None,
            "data_dir": str(data_dir),
            "n_clips": n_clips,
            "experiments": [],
        }
        summary_path = output_root / "experiment_summary.json"

        def _write_summary():
            summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

        _write_summary()

        # ------------------------------------------------------------------
        # 5. Run each experiment
        # ------------------------------------------------------------------
        trainer = SelvaLoraTrainer()
        pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
        all_curve_data = []  # collected for comparison image
        log_interval = 50  # matches _train_inner

        feature_utils_orig = model["feature_utils"]
        seq_cfg = model["seq_cfg"]
        variant = model["variant"]
        mode = model["mode"]

        for exp in spec["experiments"]:
            # str() so a numeric id still works as a directory name and label.
            exp_id = str(exp["id"])
            exp_desc = exp.get("description", "")
            cfg = _merge_config(base_cfg, exp)
            output_dir = output_root / exp_id

            print(f"\n[LoRA Scheduler] ── Experiment '{exp_id}' ──", flush=True)
            if exp_desc:
                print(f"[LoRA Scheduler] {exp_desc}", flush=True)

            # The record starts with the raw merged config; it is replaced by
            # the normalised values as soon as they have been parsed.
            exp_record = {
                "id": exp_id,
                "description": exp_desc,
                "config": dict(cfg),
                "results": {"status": "running"},
                "adapter_path": None,
                "output_dir": str(output_dir),
            }
            summary["experiments"].append(exp_record)

            t_start = time.monotonic()
            try:
                # Parsing and directory creation happen inside the try so that
                # one malformed experiment is recorded as failed instead of
                # aborting the rest of the sweep.
                p = self._coerce_params(cfg)
                exp_record["config"] = {
                    "rank": p["rank"], "alpha": p["alpha"], "lr": p["lr"],
                    "steps": p["steps"],
                    "batch_size": p["batch_size"], "warmup_steps": p["warmup_steps"],
                    "grad_accum": p["grad_accum"], "save_every": p["save_every"],
                    "seed": p["seed"], "target": list(p["target_suffixes"]),
                    "timestep_mode": p["timestep_mode"],
                    "logit_normal_sigma": p["logit_normal_sigma"],
                    "curriculum_switch": p["curriculum_switch"],
                    "lora_dropout": p["lora_dropout"],
                    "lora_plus_ratio": p["lora_plus_ratio"],
                }
                output_dir.mkdir(parents=True, exist_ok=True)
                _write_summary()

                # Same autograd context the trainer node sets up around
                # _train_inner.
                with torch.inference_mode(False), torch.enable_grad():
                    r = trainer._train_inner(
                        model, dataset, feature_utils_orig, seq_cfg,
                        device, dtype, variant, mode,
                        data_dir, output_dir, p["steps"], p["rank"], p["lr"],
                        p["alpha"], p["target_suffixes"], p["batch_size"],
                        p["warmup_steps"],
                        p["grad_accum"], p["save_every"], p["resume_path"], p["seed"],
                        p["timestep_mode"], p["logit_normal_sigma"],
                        p["curriculum_switch"], p["lora_dropout"],
                        p["lora_plus_ratio"],
                    )

                duration = time.monotonic() - t_start
                loss_history = r["loss_history"]
                smoothed = _smooth_losses(loss_history) if loss_history else []

                # Compute summary metrics. Step numbers assume logging started
                # at step 0 (start_step is not available here).
                if smoothed:
                    lowest = min(smoothed)
                    final_loss = round(smoothed[-1], 6)
                    min_loss = round(lowest, 6)
                    min_loss_step = (smoothed.index(lowest) + 1) * log_interval
                else:
                    final_loss = min_loss = min_loss_step = None

                exp_record["results"] = {
                    "status": "completed",
                    "final_loss": final_loss,
                    "min_loss": min_loss,
                    "min_loss_step": min_loss_step,
                    "loss_at_steps": _loss_at_steps(
                        loss_history, log_interval, p["save_every"], 0, p["steps"]
                    ),
                    "duration_seconds": round(duration, 1),
                }
                exp_record["adapter_path"] = r["adapter_path"]

                all_curve_data.append({
                    "id": exp_id,
                    "loss_history": loss_history,
                    "log_interval": log_interval,
                    "start_step": 0,
                })

            except Exception as e:
                # Per-experiment boundary: record the failure and continue to
                # the next experiment rather than aborting the whole sweep.
                # NOTE(review): this also catches any Exception-derived
                # cancellation signal the host raises during training, in which
                # case the sweep would carry on with the next experiment —
                # confirm how ComfyUI surfaces user interrupts.
                duration = time.monotonic() - t_start
                print(f"[LoRA Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
                traceback.print_exc()
                exp_record["results"] = {
                    "status": "failed",
                    "error": str(e),
                    "duration_seconds": round(duration, 1),
                }

            _write_summary()
            pbar_outer.update(1)

        # ------------------------------------------------------------------
        # 6. Finalise summary
        # ------------------------------------------------------------------
        summary["completed_at"] = datetime.now(timezone.utc).isoformat()
        _write_summary()
        print(f"\n[LoRA Scheduler] Sweep complete. Summary: {summary_path}", flush=True)

        # ------------------------------------------------------------------
        # 7. Comparison image
        # ------------------------------------------------------------------
        comparison_img = _draw_comparison_curves(all_curve_data)
        comparison_img.save(str(output_root / "loss_comparison.png"))
        comparison_tensor = _pil_to_tensor(comparison_img)

        return (str(summary_path), comparison_tensor)
|
||||||
+113
-91
@@ -220,6 +220,108 @@ def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
|
|||||||
return torch.from_numpy(arr).unsqueeze(0)
|
return torch.from_numpy(arr).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_dataset(model: dict, data_dir: Path, device) -> list:
    """Encode every usable clip in *data_dir* into training tensors.

    A VAE encoder is loaded for the model's mode, each ``.npz`` bundle is
    paired with its audio file, and the audio is encoded to latents. The
    prompt is embedded with the text encoder already present in *model*.
    Latents, clip features and sync features are padded or trimmed along
    dimension 1 to the lengths given by ``model["seq_cfg"]``.

    Clips without audio, or whose processing raises, are reported and skipped.
    The VAE is released once all clips have been processed, so the returned
    list can be reused across several training jobs (e.g. by the scheduler).

    Returns:
        A list of ``(latents, clip_features, sync_features, text_clip)`` tuples.

    Raises:
        FileNotFoundError: the VAE weight file is missing.
        ValueError: *data_dir* has no ``.npz`` files, or no clip could be loaded.
    """
    mode = model["mode"]
    seq_cfg = model["seq_cfg"]
    text_encoder = model["feature_utils"]

    def _fit_length(feat, length):
        # Zero-pad at the end of dim 1, or cut dim 1 down, to exactly `length`.
        current = feat.shape[1]
        if current < length:
            return F.pad(feat, (0, 0, 0, length - current))
        if current > length:
            return feat[:, :length, :]
        return feat

    vae_path = _SELVA_DIR / "ext" / ("v1-16.pth" if mode == "16k" else "v1-44.pth")
    if not vae_path.exists():
        raise FileNotFoundError(
            f"[LoRA Trainer] VAE weight not found: {vae_path}. "
            "Run SelVA Model Loader first to auto-download weights."
        )

    print("[LoRA Trainer] Loading VAE encoder...", flush=True)
    # The VAE stays in float32: mel_converter uses torch.stft, which requires
    # float32 input.
    vae_utils = FeaturesUtils(
        tod_vae_ckpt=str(vae_path),
        enable_conditions=False,
        mode=mode,
        need_vae_encoder=True,
    )
    vae_utils = vae_utils.to(device).eval()

    npz_files = sorted(data_dir.glob("*.npz"))
    if not npz_files:
        raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")

    prompt_map = _load_prompts(data_dir)
    fallback_prompt = data_dir.name  # directory name doubles as default prompt

    print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
    progress = comfy.utils.ProgressBar(len(npz_files))
    dataset = []

    for npz_path in npz_files:
        audio_path = _find_audio(npz_path)
        if audio_path is None:
            print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
            progress.update(1)
            continue

        bundle = _load_npz(npz_path)
        # prompts file entry > prompt stored in the bundle > directory name
        prompt = prompt_map.get(npz_path.name, bundle.get("prompt", fallback_prompt))
        print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)

        try:
            waveform = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)

            # Audio → latent through the VAE. encode_audio is @inference_mode;
            # .clone() exits inference mode.
            encoded = vae_utils.encode_audio(waveform.unsqueeze(0).to(device))
            # VAE gives [B, latent_dim, T]; the generator wants [B, T, latent_dim].
            latents = encoded.mode().clone().transpose(1, 2).cpu()
            # STFT rounding can be off by a frame — force the exact length.
            latents = _fit_length(latents, seq_cfg.latent_seq_len)

            # Text → CLIP features, using the encoder the inference model
            # already holds.
            text_clip = text_encoder.encode_text_clip([prompt]).cpu()

            # Shorter videos yield fewer frames; fixed lengths keep the
            # tensors stackable later on.
            clip_feat = _fit_length(bundle["clip_features"], seq_cfg.clip_seq_len)  # [1, N_clip, 1024]
            sync_feat = _fit_length(bundle["sync_features"], seq_cfg.sync_seq_len)  # [1, N_sync, 768]

            dataset.append((latents, clip_feat, sync_feat, text_clip))
        except Exception as e:
            print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
            traceback.print_exc()

        progress.update(1)

    # The encoder is not needed past this point — release its memory.
    del vae_utils
    soft_empty_cache()

    if not dataset:
        raise ValueError("[LoRA Trainer] No clips could be loaded.")
    print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)

    return dataset
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Node
|
# Node
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -358,102 +460,14 @@ class SelvaLoraTrainer:
|
|||||||
alpha_val = float(alpha) if alpha > 0.0 else float(rank)
|
alpha_val = float(alpha) if alpha > 0.0 else float(rank)
|
||||||
target_suffixes = tuple(target.strip().split())
|
target_suffixes = tuple(target.strip().split())
|
||||||
|
|
||||||
# --- Load VAE encoder (not present in inference model) ---
|
dataset = _prepare_dataset(model, data_dir, device)
|
||||||
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
|
|
||||||
vae_path = _SELVA_DIR / "ext" / vae_name
|
|
||||||
if not vae_path.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"[LoRA Trainer] VAE weight not found: {vae_path}. "
|
|
||||||
"Run SelVA Model Loader first to auto-download weights."
|
|
||||||
)
|
|
||||||
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
|
|
||||||
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
|
|
||||||
vae_utils = FeaturesUtils(
|
|
||||||
tod_vae_ckpt=str(vae_path),
|
|
||||||
enable_conditions=False,
|
|
||||||
mode=mode,
|
|
||||||
need_vae_encoder=True,
|
|
||||||
).to(device).eval()
|
|
||||||
|
|
||||||
# --- Pre-load dataset ---
|
|
||||||
npz_files = sorted(data_dir.glob("*.npz"))
|
|
||||||
if not npz_files:
|
|
||||||
raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")
|
|
||||||
|
|
||||||
prompt_map = _load_prompts(data_dir)
|
|
||||||
default_prompt = data_dir.name
|
|
||||||
|
|
||||||
print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
|
|
||||||
pbar_load = comfy.utils.ProgressBar(len(npz_files))
|
|
||||||
dataset = []
|
|
||||||
|
|
||||||
for npz_path in npz_files:
|
|
||||||
audio_path = _find_audio(npz_path)
|
|
||||||
if audio_path is None:
|
|
||||||
print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
|
|
||||||
pbar_load.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
bundle = _load_npz(npz_path)
|
|
||||||
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
|
|
||||||
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
|
||||||
|
|
||||||
# Audio → latent via VAE (float32: mel_converter/stft require float32)
|
|
||||||
# encode_audio is @inference_mode — .clone() exits inference mode
|
|
||||||
audio_b = audio.unsqueeze(0).to(device)
|
|
||||||
dist = vae_utils.encode_audio(audio_b)
|
|
||||||
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
|
|
||||||
x1 = dist.mode().clone().transpose(1, 2).cpu()
|
|
||||||
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
|
|
||||||
tgt = seq_cfg.latent_seq_len
|
|
||||||
if x1.shape[1] < tgt:
|
|
||||||
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
|
|
||||||
elif x1.shape[1] > tgt:
|
|
||||||
x1 = x1[:, :tgt, :]
|
|
||||||
|
|
||||||
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
|
||||||
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
|
||||||
|
|
||||||
# Pad/trim clip and sync features to fixed seq lengths — clips from
|
|
||||||
# shorter videos have fewer frames and would cause stack() to fail
|
|
||||||
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
|
|
||||||
c_tgt = seq_cfg.clip_seq_len
|
|
||||||
if clip_f.shape[1] < c_tgt:
|
|
||||||
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
|
|
||||||
elif clip_f.shape[1] > c_tgt:
|
|
||||||
clip_f = clip_f[:, :c_tgt, :]
|
|
||||||
|
|
||||||
sync_f = bundle["sync_features"] # [1, N_sync, 768]
|
|
||||||
s_tgt = seq_cfg.sync_seq_len
|
|
||||||
if sync_f.shape[1] < s_tgt:
|
|
||||||
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
|
|
||||||
elif sync_f.shape[1] > s_tgt:
|
|
||||||
sync_f = sync_f[:, :s_tgt, :]
|
|
||||||
|
|
||||||
dataset.append((x1, clip_f, sync_f, text_clip))
|
|
||||||
except Exception as e:
|
|
||||||
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
pbar_load.update(1)
|
|
||||||
|
|
||||||
# VAE no longer needed — free memory
|
|
||||||
del vae_utils
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
if not dataset:
|
|
||||||
raise ValueError("[LoRA Trainer] No clips could be loaded.")
|
|
||||||
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
|
|
||||||
|
|
||||||
# ComfyUI executes nodes inside torch.inference_mode(). Inference tensors
|
# ComfyUI executes nodes inside torch.inference_mode(). Inference tensors
|
||||||
# can't participate in autograd even with enable_grad — disable inference
|
# can't participate in autograd even with enable_grad — disable inference
|
||||||
# mode entirely so deepcopy, apply_lora, and the training loop all run
|
# mode entirely so deepcopy, apply_lora, and the training loop all run
|
||||||
# with a clean autograd context.
|
# with a clean autograd context.
|
||||||
with torch.inference_mode(False), torch.enable_grad():
|
with torch.inference_mode(False), torch.enable_grad():
|
||||||
return self._train_inner(
|
r = self._train_inner(
|
||||||
model, dataset, feature_utils_orig, seq_cfg,
|
model, dataset, feature_utils_orig, seq_cfg,
|
||||||
device, dtype, variant, mode,
|
device, dtype, variant, mode,
|
||||||
data_dir, output_dir, steps, rank, lr,
|
data_dir, output_dir, steps, rank, lr,
|
||||||
@@ -462,6 +476,7 @@ class SelvaLoraTrainer:
|
|||||||
timestep_mode, logit_normal_sigma, curriculum_switch,
|
timestep_mode, logit_normal_sigma, curriculum_switch,
|
||||||
lora_dropout, lora_plus_ratio,
|
lora_dropout, lora_plus_ratio,
|
||||||
)
|
)
|
||||||
|
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
|
||||||
|
|
||||||
def _train_inner(
|
def _train_inner(
|
||||||
self, model, dataset, feature_utils_orig, seq_cfg,
|
self, model, dataset, feature_utils_orig, seq_cfg,
|
||||||
@@ -677,4 +692,11 @@ class SelvaLoraTrainer:
|
|||||||
patched = {**model, "generator": generator}
|
patched = {**model, "generator": generator}
|
||||||
|
|
||||||
loss_curve = _pil_to_tensor(smoothed_img)
|
loss_curve = _pil_to_tensor(smoothed_img)
|
||||||
return (patched, str(final_path), loss_curve)
|
return {
|
||||||
|
"patched_model": patched,
|
||||||
|
"adapter_path": str(final_path),
|
||||||
|
"loss_curve": loss_curve,
|
||||||
|
"loss_history": loss_history,
|
||||||
|
"meta": meta,
|
||||||
|
"completed": True,
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user