feat: add SelVA LoRA Scheduler node for automated experiment sweeps

- Extract _prepare_dataset() from SelvaLoraTrainer.train() as a module-level
  function so the dataset can be encoded once and reused across experiments
- Change _train_inner() return value from tuple to dict (adds loss_history,
  meta, completed; train() unpacks for ComfyUI — no change to node outputs)
- New SelvaLoraScheduler node: reads a JSON sweep file, runs N experiments
  sequentially, writes experiment_summary.json (updated after each run) and
  loss_comparison.png with all smoothed curves overlaid on the same axes
- Register SelvaLoraScheduler in nodes/__init__.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-06 13:03:21 +02:00
parent 9bc2568543
commit 3ec380a27e
3 changed files with 550 additions and 91 deletions
+1
View File
@@ -7,6 +7,7 @@ _NODES = {
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
"SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"),
}
for key, (module_path, class_name, display_name) in _NODES.items():
+436
View File
@@ -0,0 +1,436 @@
"""SelVA LoRA Scheduler — runs a sweep of training experiments from a JSON file.
Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.
JSON format:
{
"name": "tier1_sweep",
"description": "optional human note",
"data_dir": "dataset/dog_bark",
"output_root": "lora_output/tier1_sweep",
"base": { "rank": 16, "lr": 1e-4, "steps": 2000, ... },
"experiments": [
{"id": "baseline", "description": "..."},
{"id": "lora_plus_16", "lora_plus_ratio": 16.0},
...
]
}
"""
import copy
import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
from PIL import Image, ImageDraw
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
SelvaLoraTrainer,
_prepare_dataset,
_smooth_losses,
_pil_to_tensor,
)
# Defaults mirror SelvaLoraTrainer INPUT_TYPES defaults
_PARAM_DEFAULTS = {
"alpha": 0.0,
"target": "attn.qkv",
"batch_size": 4,
"warmup_steps": 100,
"grad_accum": 1,
"save_every": 500,
"resume_path": "",
"seed": 42,
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
}
# Palette for comparison chart: one color per experiment (cycles if > 8)
_PALETTE = [
(66, 133, 244), # blue
(234, 67, 53), # red
(52, 168, 83), # green
(251, 188, 5), # yellow
(155, 89, 182), # purple
(26, 188, 156), # teal
(230, 126, 34), # orange
(149, 165, 166), # grey
]
def _resolve_path(raw: str) -> Path:
"""Resolve path the same way SelvaLoraTrainer does (relative → ComfyUI output dir)."""
p = Path(raw.strip())
unix_style_on_windows = (
sys.platform == "win32" and p.is_absolute() and not p.drive
)
if not p.is_absolute() or unix_style_on_windows:
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
return p
def _merge_config(base: dict, experiment: dict) -> dict:
"""Merge base defaults + file base + experiment overrides."""
cfg = dict(_PARAM_DEFAULTS)
cfg.update(base)
# Don't carry id/description into the training params
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
return cfg
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
start_step: int, total_steps: int) -> dict:
"""Build a dict of {step: loss} at each save_every boundary.
loss_history[i] = average loss over steps [start + i*log_interval + 1 …
start + (i+1)*log_interval].
"""
result = {}
targets = range(save_every, total_steps + 1, save_every)
for target in targets:
# index of the loss entry nearest to this step
idx = (target - start_step) // log_interval - 1
if 0 <= idx < len(loss_history):
result[str(target)] = round(loss_history[idx], 6)
return result
def _draw_comparison_curves(
experiments_data: list, # list of dicts: {id, loss_history, log_interval, start_step}
) -> Image.Image:
"""Draw all smoothed loss curves on the same axes, one color per experiment."""
W, H = 900, 420
pl, pr, pt, pb = 75, 160, 30, 50 # wider right margin for legend
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr
ph = H - pt - pb
# Collect all smoothed series
series = []
for i, ed in enumerate(experiments_data):
lh = ed.get("loss_history") or []
if len(lh) < 2:
continue
sm = _smooth_losses(lh)
series.append({
"id": ed["id"],
"smoothed": sm,
"log_interval": ed.get("log_interval", 50),
"start_step": ed.get("start_step", 0),
"color": _PALETTE[i % len(_PALETTE)],
})
if not series:
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
return img
all_vals = [v for s in series for v in s["smoothed"]]
lo, hi = min(all_vals), max(all_vals)
if hi == lo:
hi = lo + 1e-6
rng = hi - lo
# Horizontal grid + y-axis labels
for i in range(5):
y = pt + int(i * ph / 4)
val = hi - i * rng / 4
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
# Draw each curve
for s in series:
n = len(s["smoothed"])
pts = []
for j, v in enumerate(s["smoothed"]):
x = pl + int(j * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
pts.append((x, y))
draw.line(pts, fill=s["color"], width=2)
# Axes
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 8), "Loss comparison (smoothed)", fill=(40, 40, 40))
# Legend (right side)
lx = W - pr + 10
ly = pt
for s in series:
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
ly += 20
return img
class SelvaLoraScheduler:
"""Runs a sweep of LoRA training experiments defined in a JSON file.
The dataset (VAE encoding + .npz loading) is performed once and shared
across all experiments. Each experiment deep-copies the generator and trains
independently. Results are written to `experiment_summary.json` after every
completed run so partial results are preserved if the sweep is interrupted.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — share this file to compare runs.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of LoRA training experiments defined in a JSON sweep file. "
"The dataset is encoded once and reused across all experiments. "
"Results (loss, config, adapter paths) are collected in experiment_summary.json."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
"models directory; absolute paths are used as-is. "
"See LORA_TRAINING.md for the file format."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate the JSON file
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
# Try relative to ComfyUI models dir first, then output dir
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[LoRA Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[LoRA Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[LoRA Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[LoRA Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"lora_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
dtype = model["dtype"]
print(f"\n[LoRA Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[LoRA Scheduler] {description}", flush=True)
print(f"[LoRA Scheduler] data_dir = {data_dir}", flush=True)
print(f"[LoRA Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load + encode dataset once
# ------------------------------------------------------------------
n_clips = len(list(data_dir.glob("*.npz")))
dataset = _prepare_dataset(model, data_dir, device)
# ------------------------------------------------------------------
# 4. Build the summary skeleton (written incrementally)
# ------------------------------------------------------------------
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"data_dir": str(data_dir),
"n_clips": n_clips,
"experiments": [],
}
summary_path = output_root / "experiment_summary.json"
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
_write_summary()
# ------------------------------------------------------------------
# 5. Run each experiment
# ------------------------------------------------------------------
trainer = SelvaLoraTrainer()
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
all_curve_data = [] # collected for comparison image
log_interval = 50 # matches _train_inner
feature_utils_orig = model["feature_utils"]
seq_cfg = model["seq_cfg"]
variant = model["variant"]
mode = model["mode"]
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
cfg = _merge_config(base_cfg, exp)
# Required training params
steps = int(cfg.get("steps", 2000))
rank = int(cfg.get("rank", 16))
lr = float(cfg.get("lr", 1e-4))
alpha = float(cfg.get("alpha", 0.0))
target = str(cfg.get("target", "attn.qkv"))
batch_size = int(cfg.get("batch_size", 4))
warmup = int(cfg.get("warmup_steps", 100))
grad_accum = int(cfg.get("grad_accum", 1))
save_every = int(cfg.get("save_every", 500))
resume_path = str(cfg.get("resume_path", ""))
seed = int(cfg.get("seed", 42))
ts_mode = str(cfg.get("timestep_mode", "uniform"))
ln_sigma = float(cfg.get("logit_normal_sigma", 1.0))
curr_switch = float(cfg.get("curriculum_switch", 0.6))
dropout = float(cfg.get("lora_dropout", 0.0))
plus_ratio = float(cfg.get("lora_plus_ratio", 1.0))
alpha_val = alpha if alpha > 0.0 else float(rank)
target_suffixes = tuple(target.strip().split())
output_dir = output_root / exp_id
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\n[LoRA Scheduler] ── Experiment '{exp_id}' ──", flush=True)
if exp_desc:
print(f"[LoRA Scheduler] {exp_desc}", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"rank": rank, "alpha": alpha_val, "lr": lr, "steps": steps,
"batch_size": batch_size, "warmup_steps": warmup,
"grad_accum": grad_accum, "save_every": save_every,
"seed": seed, "target": list(target_suffixes),
"timestep_mode": ts_mode, "logit_normal_sigma": ln_sigma,
"curriculum_switch": curr_switch,
"lora_dropout": dropout, "lora_plus_ratio": plus_ratio,
},
"results": {"status": "running"},
"adapter_path": None,
"output_dir": str(output_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
with torch.inference_mode(False), torch.enable_grad():
r = trainer._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup,
grad_accum, save_every, resume_path, seed,
ts_mode, ln_sigma, curr_switch, dropout, plus_ratio,
)
duration = time.monotonic() - t_start
loss_history = r["loss_history"]
smoothed = _smooth_losses(loss_history) if loss_history else []
# Compute summary metrics
final_loss = round(smoothed[-1], 6) if smoothed else None
min_loss = round(min(smoothed), 6) if smoothed else None
min_idx = smoothed.index(min(smoothed)) if smoothed else None
min_loss_step = (min_idx + 1) * log_interval if min_idx is not None else None
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval, save_every, 0, steps
),
"duration_seconds": round(duration, 1),
}
exp_record["adapter_path"] = r["adapter_path"]
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
"log_interval": log_interval,
"start_step": 0,
})
except Exception as e:
duration = time.monotonic() - t_start
print(f"[LoRA Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
# Continue to next experiment rather than aborting the whole sweep
continue
_write_summary()
pbar_outer.update(1)
# ------------------------------------------------------------------
# 6. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[LoRA Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 7. Comparison image
# ------------------------------------------------------------------
comparison_img = _draw_comparison_curves(all_curve_data)
comparison_img.save(str(output_root / "loss_comparison.png"))
comparison_tensor = _pil_to_tensor(comparison_img)
return (str(summary_path), comparison_tensor)
+113 -91
View File
@@ -220,6 +220,108 @@ def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
return torch.from_numpy(arr).unsqueeze(0)
def _prepare_dataset(model: dict, data_dir: Path, device) -> list:
"""Load VAE, encode audio clips, load .npz features.
Returns a list of (latents, clip_features, sync_features, text_clip) CPU tensors.
The VAE is freed after encoding. Call this once and reuse the dataset across
multiple training jobs (e.g. in the scheduler).
"""
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
vae_path = _SELVA_DIR / "ext" / vae_name
if not vae_path.exists():
raise FileNotFoundError(
f"[LoRA Trainer] VAE weight not found: {vae_path}. "
"Run SelVA Model Loader first to auto-download weights."
)
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
vae_utils = FeaturesUtils(
tod_vae_ckpt=str(vae_path),
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
).to(device).eval()
npz_files = sorted(data_dir.glob("*.npz"))
if not npz_files:
raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")
prompt_map = _load_prompts(data_dir)
default_prompt = data_dir.name
print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
pbar_load = comfy.utils.ProgressBar(len(npz_files))
dataset = []
for npz_path in npz_files:
audio_path = _find_audio(npz_path)
if audio_path is None:
print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
pbar_load.update(1)
continue
bundle = _load_npz(npz_path)
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)
try:
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
# Audio → latent via VAE (float32: mel_converter/stft require float32)
# encode_audio is @inference_mode — .clone() exits inference mode
audio_b = audio.unsqueeze(0).to(device)
dist = vae_utils.encode_audio(audio_b)
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
x1 = dist.mode().clone().transpose(1, 2).cpu()
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
tgt = seq_cfg.latent_seq_len
if x1.shape[1] < tgt:
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
elif x1.shape[1] > tgt:
x1 = x1[:, :tgt, :]
# Text → CLIP features (reuse already-loaded CLIP from inference model)
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
# Pad/trim clip and sync features to fixed seq lengths — clips from
# shorter videos have fewer frames and would cause stack() to fail
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
sync_f = bundle["sync_features"] # [1, N_sync, 768]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e:
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
traceback.print_exc()
pbar_load.update(1)
# VAE no longer needed — free memory
del vae_utils
soft_empty_cache()
if not dataset:
raise ValueError("[LoRA Trainer] No clips could be loaded.")
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
return dataset
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
@@ -358,102 +460,14 @@ class SelvaLoraTrainer:
alpha_val = float(alpha) if alpha > 0.0 else float(rank)
target_suffixes = tuple(target.strip().split())
# --- Load VAE encoder (not present in inference model) ---
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
vae_path = _SELVA_DIR / "ext" / vae_name
if not vae_path.exists():
raise FileNotFoundError(
f"[LoRA Trainer] VAE weight not found: {vae_path}. "
"Run SelVA Model Loader first to auto-download weights."
)
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
vae_utils = FeaturesUtils(
tod_vae_ckpt=str(vae_path),
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
).to(device).eval()
# --- Pre-load dataset ---
npz_files = sorted(data_dir.glob("*.npz"))
if not npz_files:
raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")
prompt_map = _load_prompts(data_dir)
default_prompt = data_dir.name
print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
pbar_load = comfy.utils.ProgressBar(len(npz_files))
dataset = []
for npz_path in npz_files:
audio_path = _find_audio(npz_path)
if audio_path is None:
print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
pbar_load.update(1)
continue
bundle = _load_npz(npz_path)
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)
try:
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
# Audio → latent via VAE (float32: mel_converter/stft require float32)
# encode_audio is @inference_mode — .clone() exits inference mode
audio_b = audio.unsqueeze(0).to(device)
dist = vae_utils.encode_audio(audio_b)
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
x1 = dist.mode().clone().transpose(1, 2).cpu()
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
tgt = seq_cfg.latent_seq_len
if x1.shape[1] < tgt:
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
elif x1.shape[1] > tgt:
x1 = x1[:, :tgt, :]
# Text → CLIP features (reuse already-loaded CLIP from inference model)
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
# Pad/trim clip and sync features to fixed seq lengths — clips from
# shorter videos have fewer frames and would cause stack() to fail
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
sync_f = bundle["sync_features"] # [1, N_sync, 768]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e:
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
traceback.print_exc()
pbar_load.update(1)
# VAE no longer needed — free memory
del vae_utils
soft_empty_cache()
if not dataset:
raise ValueError("[LoRA Trainer] No clips could be loaded.")
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
dataset = _prepare_dataset(model, data_dir, device)
# ComfyUI executes nodes inside torch.inference_mode(). Inference tensors
# can't participate in autograd even with enable_grad — disable inference
# mode entirely so deepcopy, apply_lora, and the training loop all run
# with a clean autograd context.
with torch.inference_mode(False), torch.enable_grad():
return self._train_inner(
r = self._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
@@ -462,6 +476,7 @@ class SelvaLoraTrainer:
timestep_mode, logit_normal_sigma, curriculum_switch,
lora_dropout, lora_plus_ratio,
)
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
def _train_inner(
self, model, dataset, feature_utils_orig, seq_cfg,
@@ -677,4 +692,11 @@ class SelvaLoraTrainer:
patched = {**model, "generator": generator}
loss_curve = _pil_to_tensor(smoothed_img)
return (patched, str(final_path), loss_curve)
return {
"patched_model": patched,
"adapter_path": str(final_path),
"loss_curve": loss_curve,
"loss_history": loss_history,
"meta": meta,
"completed": True,
}