feat: add SelVA LoRA Scheduler node for automated experiment sweeps
- Extract _prepare_dataset() from SelvaLoraTrainer.train() as a module-level function so the dataset can be encoded once and reused across experiments - Change _train_inner() return value from tuple to dict (adds loss_history, meta, completed; train() unpacks for ComfyUI — no change to node outputs) - New SelvaLoraScheduler node: reads a JSON sweep file, runs N experiments sequentially, writes experiment_summary.json (updated after each run) and loss_comparison.png with all smoothed curves overlaid on the same axes - Register SelvaLoraScheduler in nodes/__init__.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+113
-91
@@ -220,6 +220,108 @@ def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
|
||||
return torch.from_numpy(arr).unsqueeze(0)
|
||||
|
||||
|
||||
def _prepare_dataset(model: dict, data_dir: Path, device) -> list:
|
||||
"""Load VAE, encode audio clips, load .npz features.
|
||||
|
||||
Returns a list of (latents, clip_features, sync_features, text_clip) CPU tensors.
|
||||
The VAE is freed after encoding. Call this once and reuse the dataset across
|
||||
multiple training jobs (e.g. in the scheduler).
|
||||
"""
|
||||
mode = model["mode"]
|
||||
seq_cfg = model["seq_cfg"]
|
||||
feature_utils_orig = model["feature_utils"]
|
||||
|
||||
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
|
||||
vae_path = _SELVA_DIR / "ext" / vae_name
|
||||
if not vae_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"[LoRA Trainer] VAE weight not found: {vae_path}. "
|
||||
"Run SelVA Model Loader first to auto-download weights."
|
||||
)
|
||||
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
|
||||
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
|
||||
vae_utils = FeaturesUtils(
|
||||
tod_vae_ckpt=str(vae_path),
|
||||
enable_conditions=False,
|
||||
mode=mode,
|
||||
need_vae_encoder=True,
|
||||
).to(device).eval()
|
||||
|
||||
npz_files = sorted(data_dir.glob("*.npz"))
|
||||
if not npz_files:
|
||||
raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")
|
||||
|
||||
prompt_map = _load_prompts(data_dir)
|
||||
default_prompt = data_dir.name
|
||||
|
||||
print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
|
||||
pbar_load = comfy.utils.ProgressBar(len(npz_files))
|
||||
dataset = []
|
||||
|
||||
for npz_path in npz_files:
|
||||
audio_path = _find_audio(npz_path)
|
||||
if audio_path is None:
|
||||
print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
|
||||
pbar_load.update(1)
|
||||
continue
|
||||
|
||||
bundle = _load_npz(npz_path)
|
||||
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
|
||||
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)
|
||||
|
||||
try:
|
||||
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||
|
||||
# Audio → latent via VAE (float32: mel_converter/stft require float32)
|
||||
# encode_audio is @inference_mode — .clone() exits inference mode
|
||||
audio_b = audio.unsqueeze(0).to(device)
|
||||
dist = vae_utils.encode_audio(audio_b)
|
||||
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
|
||||
x1 = dist.mode().clone().transpose(1, 2).cpu()
|
||||
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
|
||||
tgt = seq_cfg.latent_seq_len
|
||||
if x1.shape[1] < tgt:
|
||||
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
|
||||
elif x1.shape[1] > tgt:
|
||||
x1 = x1[:, :tgt, :]
|
||||
|
||||
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
||||
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
||||
|
||||
# Pad/trim clip and sync features to fixed seq lengths — clips from
|
||||
# shorter videos have fewer frames and would cause stack() to fail
|
||||
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
|
||||
c_tgt = seq_cfg.clip_seq_len
|
||||
if clip_f.shape[1] < c_tgt:
|
||||
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
|
||||
elif clip_f.shape[1] > c_tgt:
|
||||
clip_f = clip_f[:, :c_tgt, :]
|
||||
|
||||
sync_f = bundle["sync_features"] # [1, N_sync, 768]
|
||||
s_tgt = seq_cfg.sync_seq_len
|
||||
if sync_f.shape[1] < s_tgt:
|
||||
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
|
||||
elif sync_f.shape[1] > s_tgt:
|
||||
sync_f = sync_f[:, :s_tgt, :]
|
||||
|
||||
dataset.append((x1, clip_f, sync_f, text_clip))
|
||||
except Exception as e:
|
||||
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
|
||||
traceback.print_exc()
|
||||
|
||||
pbar_load.update(1)
|
||||
|
||||
# VAE no longer needed — free memory
|
||||
del vae_utils
|
||||
soft_empty_cache()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("[LoRA Trainer] No clips could be loaded.")
|
||||
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -358,102 +460,14 @@ class SelvaLoraTrainer:
|
||||
alpha_val = float(alpha) if alpha > 0.0 else float(rank)
|
||||
target_suffixes = tuple(target.strip().split())
|
||||
|
||||
# --- Load VAE encoder (not present in inference model) ---
|
||||
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
|
||||
vae_path = _SELVA_DIR / "ext" / vae_name
|
||||
if not vae_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"[LoRA Trainer] VAE weight not found: {vae_path}. "
|
||||
"Run SelVA Model Loader first to auto-download weights."
|
||||
)
|
||||
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
|
||||
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
|
||||
vae_utils = FeaturesUtils(
|
||||
tod_vae_ckpt=str(vae_path),
|
||||
enable_conditions=False,
|
||||
mode=mode,
|
||||
need_vae_encoder=True,
|
||||
).to(device).eval()
|
||||
|
||||
# --- Pre-load dataset ---
|
||||
npz_files = sorted(data_dir.glob("*.npz"))
|
||||
if not npz_files:
|
||||
raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")
|
||||
|
||||
prompt_map = _load_prompts(data_dir)
|
||||
default_prompt = data_dir.name
|
||||
|
||||
print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
|
||||
pbar_load = comfy.utils.ProgressBar(len(npz_files))
|
||||
dataset = []
|
||||
|
||||
for npz_path in npz_files:
|
||||
audio_path = _find_audio(npz_path)
|
||||
if audio_path is None:
|
||||
print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
|
||||
pbar_load.update(1)
|
||||
continue
|
||||
|
||||
bundle = _load_npz(npz_path)
|
||||
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
|
||||
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)
|
||||
|
||||
try:
|
||||
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||
|
||||
# Audio → latent via VAE (float32: mel_converter/stft require float32)
|
||||
# encode_audio is @inference_mode — .clone() exits inference mode
|
||||
audio_b = audio.unsqueeze(0).to(device)
|
||||
dist = vae_utils.encode_audio(audio_b)
|
||||
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
|
||||
x1 = dist.mode().clone().transpose(1, 2).cpu()
|
||||
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
|
||||
tgt = seq_cfg.latent_seq_len
|
||||
if x1.shape[1] < tgt:
|
||||
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
|
||||
elif x1.shape[1] > tgt:
|
||||
x1 = x1[:, :tgt, :]
|
||||
|
||||
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
||||
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
||||
|
||||
# Pad/trim clip and sync features to fixed seq lengths — clips from
|
||||
# shorter videos have fewer frames and would cause stack() to fail
|
||||
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
|
||||
c_tgt = seq_cfg.clip_seq_len
|
||||
if clip_f.shape[1] < c_tgt:
|
||||
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
|
||||
elif clip_f.shape[1] > c_tgt:
|
||||
clip_f = clip_f[:, :c_tgt, :]
|
||||
|
||||
sync_f = bundle["sync_features"] # [1, N_sync, 768]
|
||||
s_tgt = seq_cfg.sync_seq_len
|
||||
if sync_f.shape[1] < s_tgt:
|
||||
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
|
||||
elif sync_f.shape[1] > s_tgt:
|
||||
sync_f = sync_f[:, :s_tgt, :]
|
||||
|
||||
dataset.append((x1, clip_f, sync_f, text_clip))
|
||||
except Exception as e:
|
||||
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
|
||||
traceback.print_exc()
|
||||
|
||||
pbar_load.update(1)
|
||||
|
||||
# VAE no longer needed — free memory
|
||||
del vae_utils
|
||||
soft_empty_cache()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("[LoRA Trainer] No clips could be loaded.")
|
||||
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
|
||||
dataset = _prepare_dataset(model, data_dir, device)
|
||||
|
||||
# ComfyUI executes nodes inside torch.inference_mode(). Inference tensors
|
||||
# can't participate in autograd even with enable_grad — disable inference
|
||||
# mode entirely so deepcopy, apply_lora, and the training loop all run
|
||||
# with a clean autograd context.
|
||||
with torch.inference_mode(False), torch.enable_grad():
|
||||
return self._train_inner(
|
||||
r = self._train_inner(
|
||||
model, dataset, feature_utils_orig, seq_cfg,
|
||||
device, dtype, variant, mode,
|
||||
data_dir, output_dir, steps, rank, lr,
|
||||
@@ -462,6 +476,7 @@ class SelvaLoraTrainer:
|
||||
timestep_mode, logit_normal_sigma, curriculum_switch,
|
||||
lora_dropout, lora_plus_ratio,
|
||||
)
|
||||
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
|
||||
|
||||
def _train_inner(
|
||||
self, model, dataset, feature_utils_orig, seq_cfg,
|
||||
@@ -677,4 +692,11 @@ class SelvaLoraTrainer:
|
||||
patched = {**model, "generator": generator}
|
||||
|
||||
loss_curve = _pil_to_tensor(smoothed_img)
|
||||
return (patched, str(final_path), loss_curve)
|
||||
return {
|
||||
"patched_model": patched,
|
||||
"adapter_path": str(final_path),
|
||||
"loss_curve": loss_curve,
|
||||
"loss_history": loss_history,
|
||||
"meta": meta,
|
||||
"completed": True,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user