feat: add SelVA LoRA Trainer ComfyUI node
Runs the full training loop inside ComfyUI. Reuses the CLIP model already loaded in the inference model for text encoding; only a minimal VAE encoder is loaded separately, and it is freed after dataset pre-loading.

Outputs:
- SELVA_MODEL with LoRA applied (ready to connect directly to Sampler)
- adapter_path STRING (for SelVA LoRA Loader in future sessions)
- loss_curve IMAGE (PIL-rendered line chart of the training loss, one point per 50 steps)

Progress is shown via the ComfyUI ProgressBar in two phases: dataset loading, then training steps. Resuming is supported via the resume_path input.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -6,6 +6,7 @@ _NODES = {
|
|||||||
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
|
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
|
||||||
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
|
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
|
||||||
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
|
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
|
||||||
|
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||||
|
|||||||
@@ -0,0 +1,411 @@
|
|||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||||
|
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||||
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
|
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
|
||||||
|
|
||||||
|
|
||||||
|
# Audio file extensions tried when looking for the clip paired with a .npz file.
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
# "selva" sub-folder of ComfyUI's models directory; the trainer reads the VAE
# weights from its "ext" sub-folder.
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Data helpers (mirror train_lora.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _load_prompts(data_dir: Path) -> dict:
|
||||||
|
p = data_dir / "prompts.txt"
|
||||||
|
if not p.exists():
|
||||||
|
return {}
|
||||||
|
mapping = {}
|
||||||
|
for line in p.read_text(encoding="utf-8").splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
if ":" in line:
|
||||||
|
fname, prompt = line.split(":", 1)
|
||||||
|
mapping[fname.strip()] = prompt.strip()
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def _find_audio(npz_path: Path) -> Path | None:
    """Return the audio file that shares *npz_path*'s stem, or ``None``.

    Every extension in ``_AUDIO_EXTS`` is tried by swapping the ``.npz``
    suffix. The extensions are visited in sorted order: ``_AUDIO_EXTS`` is a
    set of strings, whose iteration order can change between interpreter
    runs, so without sorting the file picked when several candidates exist
    (e.g. both ``clip.wav`` and ``clip.flac``) would not be reproducible.
    """
    for ext in sorted(_AUDIO_EXTS):
        candidate = npz_path.with_suffix(ext)
        if candidate.exists():
            return candidate
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
    """Load *path* as a mono float waveform of exactly ``duration * target_sr`` samples.

    Multi-channel audio is averaged over its first dimension, the result is
    resampled to *target_sr* when the file's rate differs, and the waveform
    is then cut or zero-padded at the end to the requested length.
    """
    samples, native_sr = torchaudio.load(str(path))

    # Down-mix when the file has more than one channel.
    channel_count = samples.shape[0]
    if channel_count > 1:
        samples = samples.mean(0, keepdim=True)
    samples = samples.squeeze(0).float()

    if native_sr != target_sr:
        resampled = torchaudio.functional.resample(
            samples.unsqueeze(0), native_sr, target_sr)
        samples = resampled.squeeze(0)

    wanted = int(duration * target_sr)
    missing = wanted - samples.shape[0]
    if missing > 0:
        # Too short: append silence.
        return F.pad(samples, (0, missing))
    return samples[:wanted]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_npz(path: Path) -> dict:
|
||||||
|
data = np.load(str(path), allow_pickle=False)
|
||||||
|
bundle = {
|
||||||
|
"clip_features": torch.from_numpy(data["clip_features"]),
|
||||||
|
"sync_features": torch.from_numpy(data["sync_features"]),
|
||||||
|
}
|
||||||
|
if "prompt" in data:
|
||||||
|
bundle["prompt"] = str(data["prompt"])
|
||||||
|
return bundle
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Loss curve rendering
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
    """Render a loss curve as a [1, H, W, 3] float32 IMAGE tensor for ComfyUI.

    Args:
        losses: One averaged loss value per logging interval, in order.
        log_interval: Number of training steps between consecutive entries;
            used only for the x-axis labels.

    Returns:
        An 800x380 white-background chart with values scaled to [0, 1].
        With fewer than two entries (or no finite entry) only the axes and
        the title are drawn.

    Non-finite entries (NaN/inf from a diverged run) are left out of the
    y-range and the line is broken around them, so a diverged run still
    produces a chart instead of raising while converting NaN to a pixel.
    """
    W, H = 800, 380
    pl, pr, pt, pb = 70, 20, 25, 45  # plot margins: left, right, top, bottom

    img = Image.new("RGB", (W, H), (255, 255, 255))
    draw = ImageDraw.Draw(img)

    pw = W - pl - pr  # plot area width
    ph = H - pt - pb  # plot area height

    n = len(losses)
    finite = [float(v) for v in losses if np.isfinite(v)]

    if n >= 2 and finite:
        lo, hi = min(finite), max(finite)
        rng = hi - lo
        if rng <= 0.0:
            # Flat curve: use a tiny artificial range to avoid dividing by zero.
            rng = 1e-6
            hi = lo + rng

        # Horizontal grid + y-axis labels
        for i in range(5):
            y = pt + int(i * ph / 4)
            val = hi - i * rng / 4
            draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
            draw.text((2, y - 7), f"{val:.4f}", fill=(120, 120, 120))

        # Loss line, drawn as separate segments between non-finite entries.
        # Each entry keeps its own x slot so gaps appear where values are missing.
        segment = []
        for i, v in enumerate(losses):
            if not np.isfinite(v):
                if len(segment) >= 2:
                    draw.line(segment, fill=(66, 133, 244), width=2)
                segment = []
                continue
            x = pl + int(i * pw / (n - 1))
            y = pt + int((1.0 - (v - lo) / rng) * ph)
            segment.append((x, y))
        if len(segment) >= 2:
            draw.line(segment, fill=(66, 133, 244), width=2)

        # x-axis step labels. The leftmost point is the first logged average
        # (step == log_interval on a fresh run) and the rightmost is step
        # n * log_interval, so labels span that range rather than starting at 0.
        # On a resumed run the labels are relative to this session's log.
        first_step = log_interval
        last_step = n * log_interval
        for i in range(5):
            x = pl + int(i * pw / 4)
            step = int(first_step + i * (last_step - first_step) / 4)
            draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))

    # Axes
    draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
    draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
    draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))

    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).unsqueeze(0)  # [1, H, W, 3]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class SelvaLoraTrainer:
    """ComfyUI node that trains a LoRA adapter on a loaded SelVA model.

    The node pre-encodes a folder of ``.npz`` feature files and their paired
    audio clips, trains LoRA weights on a deep copy of the model's generator,
    writes step checkpoints plus ``adapter_final.pt`` to the output folder,
    and returns the patched model, the adapter path and a loss-curve image.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required and optional inputs for ComfyUI."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "data_dir": ("STRING", {
                    "default": "",
                    "tooltip": "Directory containing .npz feature files and paired audio files.",
                }),
                "output_dir": ("STRING", {
                    "default": "lora_output",
                    "tooltip": "Where to save adapter checkpoints.",
                }),
                "steps": ("INT", {
                    "default": 2000, "min": 100, "max": 100000,
                    "tooltip": "Total training steps.",
                }),
                "rank": ("INT", {
                    "default": 16, "min": 1, "max": 128,
                    "tooltip": "LoRA rank. Higher = more capacity, more VRAM. 16 is a safe default.",
                }),
                "lr": ("FLOAT", {
                    "default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-6,
                    "tooltip": "Learning rate.",
                }),
            },
            "optional": {
                "alpha": ("FLOAT", {
                    "default": 0.0, "min": 0.0, "max": 256.0, "step": 0.5,
                    "tooltip": "LoRA alpha. 0 = use rank value (scale = 1.0).",
                }),
                "target": ("STRING", {
                    "default": "attn.qkv",
                    "tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
                }),
                "warmup_steps": ("INT", {"default": 500, "min": 0, "max": 5000}),
                "grad_accum": ("INT", {"default": 4, "min": 1, "max": 32,
                                       "tooltip": "Gradient accumulation steps."}),
                "save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
                "resume_path": ("STRING", {
                    "default": "",
                    "tooltip": "Path to a step checkpoint (.pt) to resume training from.",
                }),
                "seed": ("INT", {"default": 42}),
            },
        }

    RETURN_TYPES = ("SELVA_MODEL", "STRING", "IMAGE")
    RETURN_NAMES = ("model", "adapter_path", "loss_curve")
    OUTPUT_TOOLTIPS = (
        "Model with trained LoRA adapter applied — connect directly to Sampler.",
        "Path to adapter_final.pt — use with SelVA LoRA Loader in future sessions.",
        "Training loss curve.",
    )
    FUNCTION = "train"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Trains a LoRA adapter on a dataset of .npz feature files + paired audio files. "
        "Blocks the queue for the duration of training. "
        "Prepare the dataset with SelVA Feature Extractor (set a name to get numbered .npz files) "
        "and pair each .npz with a clean audio file of the same stem."
    )

    def train(self, model, data_dir, output_dir, steps, rank, lr,
              alpha=0.0, target="attn.qkv", warmup_steps=500,
              grad_accum=4, save_every=500, resume_path="", seed=42):
        """Train a LoRA adapter and return the patched model.

        Args:
            model: SELVA_MODEL dict; the keys ``dtype``, ``variant``, ``mode``,
                ``seq_cfg``, ``feature_utils`` and ``generator`` are read.
            data_dir: Folder with ``.npz`` feature files, paired audio files
                and an optional ``prompts.txt``.
            output_dir: Folder for checkpoints (created if missing).
            steps: Final training step number (micro-steps, not optimiser updates).
            rank: LoRA rank.
            lr: Peak learning rate for AdamW.
            alpha: LoRA alpha; values <= 0 fall back to ``rank``.
            target: Whitespace-separated layer-name suffixes to wrap.
            warmup_steps: Length of the linear learning-rate warm-up.
            grad_accum: Micro-steps accumulated per optimiser update.
            save_every: Interval (in steps) between resumable checkpoints.
            resume_path: Optional step checkpoint to continue from.
            seed: Seed applied to ``torch`` and ``random``.

        Returns:
            ``(patched_model, adapter_path, loss_curve)``: a shallow copy of
            *model* whose ``generator`` is the trained copy, the path of
            ``adapter_final.pt`` as a string, and the loss-curve IMAGE tensor.

        Raises:
            FileNotFoundError: the VAE weight file is missing.
            ValueError: no ``.npz`` files, no loadable clips, or an unusable
                resume checkpoint.
            RuntimeError: no layer matched ``target``.
        """
        torch.manual_seed(seed)
        random.seed(seed)

        device = get_device()
        dtype = model["dtype"]
        variant = model["variant"]
        mode = model["mode"]
        seq_cfg = model["seq_cfg"]
        feature_utils_orig = model["feature_utils"]

        data_dir = Path(data_dir.strip())
        output_dir = Path(output_dir.strip())
        output_dir.mkdir(parents=True, exist_ok=True)

        alpha_val = float(alpha) if alpha > 0.0 else float(rank)
        target_suffixes = tuple(target.strip().split())

        # --- Load VAE encoder (not present in inference model) ---
        vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
        vae_path = _SELVA_DIR / "ext" / vae_name
        if not vae_path.exists():
            raise FileNotFoundError(
                f"[LoRA Trainer] VAE weight not found: {vae_path}. "
                "Run SelVA Model Loader first to auto-download weights."
            )
        print("[LoRA Trainer] Loading VAE encoder...", flush=True)
        vae_utils = FeaturesUtils(
            tod_vae_ckpt=str(vae_path),
            enable_conditions=False,
            mode=mode,
            need_vae_encoder=True,
        ).to(device, dtype).eval()

        # --- Pre-load dataset ---
        npz_files = sorted(data_dir.glob("*.npz"))
        if not npz_files:
            raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")

        prompt_map = _load_prompts(data_dir)
        # Fallback prompt when neither prompts.txt nor the .npz provides one.
        default_prompt = data_dir.name

        print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
        pbar_load = comfy.utils.ProgressBar(len(npz_files))
        dataset = []

        for npz_path in npz_files:
            audio_path = _find_audio(npz_path)
            if audio_path is None:
                print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
                pbar_load.update(1)
                continue

            bundle = _load_npz(npz_path)
            # Prompt priority: prompts.txt entry > prompt stored in the .npz > folder name.
            prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
            print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)

            # Best effort: a clip that fails to encode is reported and skipped
            # so that one bad file does not abort the whole run.
            try:
                audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)

                # Audio → latent via VAE
                audio_b = audio.unsqueeze(0).to(device, dtype)
                with torch.inference_mode():
                    dist = vae_utils.encode_audio(audio_b)
                    x1 = dist.mode().clone().cpu()

                # Text → CLIP features (reuse already-loaded CLIP from inference model)
                text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()

                dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
            except Exception as e:
                print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)

            pbar_load.update(1)

        # VAE no longer needed — free memory
        del vae_utils
        soft_empty_cache()

        if not dataset:
            raise ValueError("[LoRA Trainer] No clips could be loaded.")
        print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)

        # --- Prepare generator copy with LoRA ---
        # Train on a deep copy so the generator inside the incoming model is untouched.
        generator = copy.deepcopy(model["generator"]).to(device, dtype)

        n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
                            target_suffixes=target_suffixes)
        if n_lora == 0:
            raise RuntimeError(
                f"[LoRA Trainer] No layers matched target={target_suffixes}. "
                "Check the 'target' field."
            )
        print(f"[LoRA Trainer] Wrapped {n_lora} layers (rank={rank}, alpha={alpha_val})", flush=True)

        # Only parameters whose name contains "lora_" are trained.
        for name, p in generator.named_parameters():
            p.requires_grad_("lora_" in name)

        generator.update_seq_lengths(
            latent_seq_len=seq_cfg.latent_seq_len,
            clip_seq_len=seq_cfg.clip_seq_len,
            sync_seq_len=seq_cfg.sync_seq_len,
        )

        # --- Optimizer + scheduler ---
        lora_params = [p for p in generator.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(lora_params, lr=lr, weight_decay=1e-2)

        # NOTE(review): the scheduler advances once per optimiser update (every
        # grad_accum steps), so warmup_steps is counted in updates, not training
        # steps. With the defaults (2000 steps, grad_accum=4, warmup 500) the
        # warm-up spans the entire run — confirm this is intended.
        def lr_lambda(s):
            return s / max(1, warmup_steps) if s < warmup_steps else 1.0

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)

        # --- Resume ---
        start_step = 0
        if resume_path.strip():
            # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only
            # resume from checkpoints you created or otherwise trust.
            ckpt = torch.load(resume_path.strip(), map_location="cpu", weights_only=False)
            if "step" not in ckpt:
                raise ValueError("[LoRA Trainer] Checkpoint has no step info.")
            start_step = ckpt["step"]
            if start_step >= steps:
                raise ValueError(
                    f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
                )
            # strict=False because the checkpoint holds only the LoRA weights.
            load_result = generator.load_state_dict(ckpt["state_dict"], strict=False)
            if load_result.unexpected_keys:
                # Keys the model does not have would otherwise be dropped silently,
                # which usually means rank/target differ from the checkpoint.
                print(f"[LoRA Trainer] Warning: {len(load_result.unexpected_keys)} checkpoint "
                      "key(s) did not match the model and were ignored.", flush=True)
            optimizer.load_state_dict(ckpt["optimizer"])
            scheduler.load_state_dict(ckpt["scheduler"])
            print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)

        # --- Training loop ---
        generator.train()
        optimizer.zero_grad()

        log_interval = 50
        remaining = steps - start_step
        pbar_train = comfy.utils.ProgressBar(remaining)
        loss_history = []
        running_loss = 0.0
        # Steps accumulated since the last log line. After a resume the first
        # window can be shorter than log_interval, so the average divides by
        # this count rather than by log_interval.
        running_count = 0

        meta = {
            "variant": variant,
            "rank": rank,
            "alpha": alpha_val,
            "target": list(target_suffixes),
            "steps": steps,
        }

        print(f"\n[LoRA Trainer] Training {remaining} steps "
              f"(step {start_step + 1} → {steps})\n", flush=True)

        for step in range(start_step + 1, steps + 1):
            # One randomly chosen clip per step.
            x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)

            # generator.normalize() is called for its side effect (its return
            # value is unused), so it must modify x1 in place. copy=True
            # guarantees x1 never aliases the cached latent, which would
            # otherwise be normalised again each time the clip is drawn
            # whenever .to() is a no-op (same device and dtype).
            x1 = x1_cpu.to(device, dtype, copy=True)
            clip_f = clip_f_cpu.to(device, dtype)
            sync_f = sync_f_cpu.to(device, dtype)
            text_clip = text_clip_cpu.to(device, dtype)

            generator.normalize(x1)

            t = torch.rand(1, device=device, dtype=dtype)
            x0 = torch.randn_like(x1)
            xt = fm.get_conditional_flow(x0, x1, t)

            v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
            # Scale so the accumulated gradient is the mean over grad_accum steps.
            loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
            loss.backward()
            running_loss += loss.item() * grad_accum  # log the unscaled loss
            running_count += 1

            if step % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if step % log_interval == 0:
                avg = running_loss / running_count
                loss_history.append(avg)
                lr_now = scheduler.get_last_lr()[0]
                print(f"[LoRA Trainer] step {step:5d}/{steps} "
                      f"loss={avg:.4f} lr={lr_now:.2e}", flush=True)
                running_loss = 0.0
                running_count = 0

            if step % save_every == 0 or step == steps:
                # Resumable checkpoint: LoRA weights plus optimiser/scheduler state.
                ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
                torch.save({
                    "state_dict": get_lora_state_dict(generator),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "step": step,
                    "meta": meta,
                }, ckpt_path)
                print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)

            pbar_train.update(1)

        # Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
        final_path = output_dir / "adapter_final.pt"
        torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
        (output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
        print(f"\n[LoRA Trainer] Done. Adapter saved to {final_path}", flush=True)

        # --- Return patched model ---
        generator.eval()
        # Put the trained copy on the same device as the original generator.
        generator.to(next(model["generator"].parameters()).device)
        patched = {**model, "generator": generator}

        loss_curve = _draw_loss_curve(loss_history, log_interval)

        return (patched, str(final_path), loss_curve)
|
||||||
Reference in New Issue
Block a user