Files
ComfyUI-SelVA/nodes/selva_lora_trainer.py
T
Ethanfel 6b9adf0816 fix: fall back to soundfile when torchcodec FFmpeg libs are missing
Recent torchaudio defaults to torchcodec as the audio backend, which requires
FFmpeg shared libraries. Falls back to soundfile for envs where torchcodec
can't load (e.g. containerised ComfyUI without system FFmpeg).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:03:57 +02:00

503 lines
20 KiB
Python

import copy
import json
import random
import traceback
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from PIL import Image, ImageDraw
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from selva_core.model.utils.features_utils import FeaturesUtils
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
# Suffixes probed by _find_audio() when looking for the audio file paired with a .npz.
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
# "selva" folder under folder_paths.models_dir; train() reads the VAE weights from its "ext" sub-folder.
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
# ---------------------------------------------------------------------------
# Data helpers (mirror train_lora.py)
# ---------------------------------------------------------------------------
def _load_prompts(data_dir: Path) -> dict:
    """Read the optional ``prompts.txt`` in *data_dir* into a filename -> prompt map.

    Each useful line has the form ``<filename>: <prompt>``. Only the first colon
    separates the two parts, so the prompt itself may contain colons. Blank
    lines, lines starting with ``#`` and lines without a colon are ignored.

    Args:
        data_dir: Directory that may contain ``prompts.txt``.

    Returns:
        Dict mapping the stripped filename to the stripped prompt. Empty when
        the file does not exist.
    """
    prompts_file = data_dir / "prompts.txt"
    if not prompts_file.exists():
        return {}

    mapping = {}
    # "utf-8-sig" drops a leading byte-order mark. str.strip() does not treat
    # U+FEFF as whitespace, so with plain "utf-8" a BOM would become part of
    # the first filename (or hide a leading "#" comment).
    for line in prompts_file.read_text(encoding="utf-8-sig").splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        fname, sep, prompt = line.partition(":")
        if sep:
            mapping[fname.strip()] = prompt.strip()
    return mapping
def _find_audio(npz_path: Path) -> Path | None:
for ext in _AUDIO_EXTS:
c = npz_path.with_suffix(ext)
if c.exists():
return c
return None
def _load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
    """Load *path* as a mono float waveform of exactly ``int(duration * target_sr)`` samples.

    The file is read with ``torchaudio.load``. If that fails with an error
    whose message mentions torchcodec, the file is read with ``soundfile``
    instead. Multi-channel audio is averaged to mono, resampled to
    *target_sr* when needed, then truncated or right-padded with zeros.

    Args:
        path: Audio file to read.
        target_sr: Sample rate of the returned waveform.
        duration: Length of the returned waveform in seconds.

    Returns:
        1-D float32 tensor with ``int(duration * target_sr)`` samples.

    Raises:
        RuntimeError, ImportError: Re-raised from ``torchaudio.load`` when the
            message does not mention torchcodec.
    """
    try:
        waveform, sr = torchaudio.load(str(path))
    except (RuntimeError, ImportError) as e:
        # Depending on the torchaudio/torchcodec combination, an unusable
        # torchcodec backend surfaces as RuntimeError (FFmpeg shared libs
        # missing) or ImportError (torchcodec not installed). The substring
        # "torchcodec" also covers "libtorchcodec".
        if "torchcodec" not in str(e).lower():
            raise
        # torchcodec unavailable — fall back to soundfile
        import soundfile as sf
        data, sr = sf.read(str(path), always_2d=True)   # [frames, channels]
        waveform = torch.from_numpy(data.T).float()     # [channels, frames]

    # Down-mix to mono, then drop the channel dimension.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(0, keepdim=True)
    waveform = waveform.squeeze(0).float()

    if sr != target_sr:
        waveform = torchaudio.functional.resample(
            waveform.unsqueeze(0), sr, target_sr).squeeze(0)

    # Force the exact target length: cut the tail or zero-pad on the right.
    target_len = int(duration * target_sr)
    if waveform.shape[0] >= target_len:
        return waveform[:target_len]
    return F.pad(waveform, (0, target_len - waveform.shape[0]))
def _load_npz(path: Path) -> dict:
    """Load the feature bundle stored in an ``.npz`` file.

    Args:
        path: File containing the arrays ``clip_features`` and
            ``sync_features`` and, optionally, ``prompt``.

    Returns:
        Dict with ``"clip_features"`` and ``"sync_features"`` as tensors, plus
        ``"prompt"`` as ``str`` when the file contains one.

    Raises:
        KeyError: If ``clip_features`` or ``sync_features`` is missing.
    """
    # np.load keeps the archive open until the NpzFile is closed. The context
    # manager releases the handle; the arrays are fully read on access, so
    # they stay valid afterwards.
    with np.load(str(path), allow_pickle=False) as data:
        bundle = {
            "clip_features": torch.from_numpy(data["clip_features"]),
            "sync_features": torch.from_numpy(data["sync_features"]),
        }
        if "prompt" in data:
            prompt = data["prompt"]
            # A prompt saved as a one-element array would otherwise
            # stringify as "['...']"; unwrap it to the bare value first.
            bundle["prompt"] = str(prompt.item() if prompt.size == 1 else prompt)
    return bundle
# ---------------------------------------------------------------------------
# Eval sample
# ---------------------------------------------------------------------------
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
                 num_steps: int = 8):
    """Generate one audio preview, without CFG, from a randomly chosen dataset entry.

    The generator is put in eval mode for the duration of the call and is
    always switched back to train mode before returning. Any exception is
    printed and reported as ``(None, None)``.

    Args:
        generator: Model used for the velocity prediction and un-normalisation.
        feature_utils_orig: Object providing ``decode`` and ``vocode``; moved to
            *device* for the decode and back afterwards if it lives elsewhere.
        dataset: Sequence of 4-tuples; the last three items are used as
            clip / sync / text conditioning.
        seq_cfg: Provides ``latent_seq_len`` and ``sampling_rate``.
        device: Device to run on.
        dtype: Dtype for the conditioning tensors and the initial noise.
        num_steps: Number of Euler steps for the flow-matching sampler.

    Returns:
        ``(waveform, sample_rate)`` with a peak-normalised float32 CPU
        waveform, or ``(None, None)`` on failure.
    """
    generator.eval()
    try:
        entry = random.choice(dataset)
        _, clip_src, sync_src, text_src = entry
        clip_f = clip_src.to(device, dtype)
        sync_f = sync_src.to(device, dtype)
        text_clip = text_src.to(device, dtype)

        noise = torch.randn(
            1, seq_cfg.latent_seq_len, generator.latent_dim,
            device=device, dtype=dtype,
        )
        sampler = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)

        def velocity_fn(t, x):
            t_batch = t.reshape(1).to(device, dtype)
            return generator.forward(x, clip_f, sync_f, text_clip, t_batch)

        with torch.no_grad():
            latent = generator.unnormalize(sampler.to_data(velocity_fn, noise))

            # feature_utils_orig may be on CPU (offload strategy) — move temporarily
            home_device = next(feature_utils_orig.parameters()).device
            needs_move = home_device != device
            if needs_move:
                feature_utils_orig.to(device)
            try:
                mel = feature_utils_orig.decode(latent)
                audio = feature_utils_orig.vocode(mel)
            finally:
                if needs_move:
                    feature_utils_orig.to(home_device)

        audio = audio.float().cpu()
        # Bring the output to a [batch, 1, samples] layout.
        n_dims = audio.dim()
        if n_dims == 2:
            audio = audio.unsqueeze(1)
        elif n_dims == 3 and audio.shape[1] != 1:
            audio = audio.mean(dim=1, keepdim=True)

        # Peak-normalise; the clamp on the peak avoids dividing by zero on silence.
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)
        return audio.squeeze(0), seq_cfg.sampling_rate   # [1, L]
    except Exception as e:
        print(f"[LoRA Trainer] Eval sample failed: {e}", flush=True)
        return None, None
    finally:
        generator.train()
# ---------------------------------------------------------------------------
# Loss curve rendering
# ---------------------------------------------------------------------------
def _draw_loss_curve(losses: list[float], log_interval: int,
                     start_step: int = 0) -> Image.Image:
    """Render a loss curve as a PIL Image.

    The curve is drawn only when there are at least two entries and at least
    one finite value; otherwise only the axes and the title are drawn.
    Non-finite entries (NaN / inf) are skipped, which leaves a gap in the
    line, so a diverged run does not make rendering raise.

    Args:
        losses: Averaged loss values in logging order.
        log_interval: Number of training steps between consecutive entries.
        start_step: Step the run started from (non-zero for resumed runs).
            Used only for the x-axis labels.

    Returns:
        An 800x380 RGB image.
    """
    import math

    W, H = 800, 380
    pl, pr, pt, pb = 70, 20, 25, 45   # left / right / top / bottom padding
    img = Image.new("RGB", (W, H), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    pw = W - pl - pr                  # plot area width
    ph = H - pt - pb                  # plot area height

    n = len(losses)
    finite = [v for v in losses if math.isfinite(v)]
    if n >= 2 and finite:
        lo, hi = min(finite), max(finite)
        if hi == lo:
            hi = lo + 1e-6            # avoid a zero range for a flat curve
        rng = hi - lo

        # Horizontal grid + y-axis labels
        for i in range(5):
            y = pt + int(i * ph / 4)
            val = hi - i * rng / 4
            draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
            draw.text((2, y - 7), f"{val:.4f}", fill=(120, 120, 120))

        # Loss line, drawn as one polyline per run of consecutive finite
        # values. int() of a NaN coordinate would raise, so non-finite
        # entries end the current run instead of being plotted.
        run = []
        for i, v in enumerate(losses):
            if not math.isfinite(v):
                if len(run) >= 2:
                    draw.line(run, fill=(66, 133, 244), width=2)
                run = []
                continue
            x = pl + int(i * pw / max(n - 1, 1))
            y = pt + int((1.0 - (v - lo) / rng) * ph)
            run.append((x, y))
        if len(run) >= 2:
            draw.line(run, fill=(66, 133, 244), width=2)

        # x-axis step labels. train() records an entry whenever the step is a
        # multiple of log_interval, so the first entry is the first such
        # multiple after start_step. That equals start_step + log_interval
        # only when start_step is itself a multiple.
        if log_interval > 0:
            first_step = (start_step // log_interval + 1) * log_interval
        else:
            first_step = start_step + log_interval
        last_step = first_step + (n - 1) * log_interval
        for i in range(5):
            x = pl + int(i * pw / 4)
            step = int(first_step + i * (last_step - first_step) / 4)
            draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))

    # Axes
    draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
    draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
    draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))
    return img
def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a float32 tensor scaled by 1/255 with a leading batch axis.

    For an RGB image the result is shaped [1, H, W, 3], the IMAGE layout
    used by ComfyUI.
    """
    pixels = np.array(img)
    scaled = pixels.astype(np.float32) / 255.0
    return torch.from_numpy(scaled)[None]
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaLoraTrainer:
    """ComfyUI node that trains a LoRA adapter on .npz feature files + paired audio.

    ``train`` runs synchronously: it encodes every clip once, fine-tunes a
    deep copy of the generator with LoRA layers, writes checkpoints and
    preview samples to the output directory and returns the patched model.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's required and optional inputs."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "data_dir": ("STRING", {
                    "default": "",
                    "tooltip": "Directory containing .npz feature files and paired audio files.",
                }),
                "output_dir": ("STRING", {
                    "default": "lora_output",
                    "tooltip": "Where to save adapter checkpoints.",
                }),
                "steps": ("INT", {
                    "default": 2000, "min": 100, "max": 100000,
                    "tooltip": "Total training steps.",
                }),
                "rank": ("INT", {
                    "default": 16, "min": 1, "max": 128,
                    "tooltip": "LoRA rank. Higher = more capacity, more VRAM. 16 is a safe default.",
                }),
                "lr": ("FLOAT", {
                    "default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-6,
                    "tooltip": "Learning rate.",
                }),
            },
            "optional": {
                "alpha": ("FLOAT", {
                    "default": 0.0, "min": 0.0, "max": 256.0, "step": 0.5,
                    "tooltip": "LoRA alpha. 0 = use rank value (scale = 1.0).",
                }),
                "target": ("STRING", {
                    "default": "attn.qkv",
                    "tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
                }),
                "warmup_steps": ("INT", {"default": 500, "min": 0, "max": 5000}),
                "grad_accum": ("INT", {"default": 4, "min": 1, "max": 32,
                                       "tooltip": "Gradient accumulation steps."}),
                "save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
                "resume_path": ("STRING", {
                    "default": "",
                    "tooltip": "Path to a step checkpoint (.pt) to resume training from.",
                }),
                "seed": ("INT", {"default": 42}),
            },
        }

    RETURN_TYPES = ("SELVA_MODEL", "STRING", "IMAGE")
    RETURN_NAMES = ("model", "adapter_path", "loss_curve")
    OUTPUT_TOOLTIPS = (
        "Model with trained LoRA adapter applied — connect directly to Sampler.",
        "Path to adapter_final.pt — use with SelVA LoRA Loader in future sessions.",
        "Training loss curve.",
    )
    FUNCTION = "train"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Trains a LoRA adapter on a dataset of .npz feature files + paired audio files. "
        "Blocks the queue for the duration of training. "
        "Prepare the dataset with SelVA Feature Extractor (set a name to get numbered .npz files) "
        "and pair each .npz with a clean audio file of the same stem."
    )

    @staticmethod
    def _save_sample(wav_path: Path, wav: torch.Tensor, sr: int) -> bool:
        """Write a preview waveform to *wav_path*; return whether it succeeded.

        Saving is best-effort: a preview must not abort a training run. When
        ``torchaudio.save`` fails with an error mentioning torchcodec, the
        file is written with ``soundfile`` instead. Any other failure is
        printed as a warning.
        """
        try:
            try:
                torchaudio.save(str(wav_path), wav, sr)
            except (RuntimeError, ImportError) as e:
                if "torchcodec" not in str(e).lower():
                    raise
                # torchcodec unavailable — fall back to soundfile, which
                # expects [frames, channels] rather than [channels, frames].
                import soundfile as sf
                sf.write(str(wav_path), wav.T.numpy(), sr)
        except Exception as e:
            print(f"[LoRA Trainer] Warning: could not save sample {wav_path.name}: {e}",
                  flush=True)
            return False
        return True

    def train(self, model, data_dir, output_dir, steps, rank, lr,
              alpha=0.0, target="attn.qkv", warmup_steps=500,
              grad_accum=4, save_every=500, resume_path="", seed=42):
        """Train a LoRA adapter and return the patched model.

        Args:
            model: SELVA_MODEL dict; the keys ``dtype``, ``variant``, ``mode``,
                ``seq_cfg``, ``feature_utils`` and ``generator`` are read.
            data_dir: Directory containing .npz feature files, paired audio
                files and an optional ``prompts.txt``.
            output_dir: Directory for checkpoints, samples and the final
                adapter; created if missing.
            steps: Total number of training iterations.
            rank: LoRA rank.
            lr: Learning rate.
            alpha: LoRA alpha; values <= 0 fall back to *rank*.
            target: Space-separated layer name suffixes to wrap.
            warmup_steps: Length of the linear LR warm-up. NOTE: the scheduler
                advances once per optimizer update, i.e. once every
                *grad_accum* iterations.
            grad_accum: Iterations accumulated per optimizer update.
            save_every: Checkpoint interval in iterations.
            resume_path: Step checkpoint to resume from; empty = fresh run.
            seed: Seed for ``torch`` and ``random``.

        Returns:
            ``(patched_model, adapter_path, loss_curve)``.

        Raises:
            FileNotFoundError: If the VAE weight file is missing.
            ValueError: If no clip could be loaded, or the resume checkpoint
                has no step or is already at or beyond *steps*.
            RuntimeError: If no layer matches *target*.
        """
        torch.manual_seed(seed)
        random.seed(seed)

        device = get_device()
        dtype = model["dtype"]
        variant = model["variant"]
        mode = model["mode"]
        seq_cfg = model["seq_cfg"]
        feature_utils_orig = model["feature_utils"]

        data_dir = Path(data_dir.strip())
        output_dir = Path(output_dir.strip())
        output_dir.mkdir(parents=True, exist_ok=True)

        alpha_val = float(alpha) if alpha > 0.0 else float(rank)
        target_suffixes = tuple(target.strip().split())

        # --- Load VAE encoder (not present in inference model) ---
        vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
        vae_path = _SELVA_DIR / "ext" / vae_name
        if not vae_path.exists():
            raise FileNotFoundError(
                f"[LoRA Trainer] VAE weight not found: {vae_path}. "
                "Run SelVA Model Loader first to auto-download weights."
            )
        print("[LoRA Trainer] Loading VAE encoder...", flush=True)
        # Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
        vae_utils = FeaturesUtils(
            tod_vae_ckpt=str(vae_path),
            enable_conditions=False,
            mode=mode,
            need_vae_encoder=True,
        ).to(device).eval()

        # --- Pre-load dataset ---
        npz_files = sorted(data_dir.glob("*.npz"))
        if not npz_files:
            raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")

        prompt_map = _load_prompts(data_dir)
        default_prompt = data_dir.name

        print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
        pbar_load = comfy.utils.ProgressBar(len(npz_files))
        dataset = []
        for npz_path in npz_files:
            audio_path = _find_audio(npz_path)
            if audio_path is None:
                print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
                pbar_load.update(1)
                continue
            try:
                # Reading the .npz is inside the try so that one unreadable or
                # incomplete feature file is skipped with a warning, like a
                # bad audio file, instead of aborting the whole run.
                bundle = _load_npz(npz_path)
                # Prompt priority: prompts.txt entry, prompt stored in the
                # .npz, then the dataset directory name.
                prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
                print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)

                audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
                # Audio → latent via VAE (float32: mel_converter/stft require float32)
                # encode_audio is @inference_mode — .clone() exits inference mode
                audio_b = audio.unsqueeze(0).to(device)
                dist = vae_utils.encode_audio(audio_b)
                x1 = dist.mode().clone().cpu()
                # Text → CLIP features (reuse already-loaded CLIP from inference model)
                text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
                dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
            except Exception as e:
                print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
                traceback.print_exc()
            pbar_load.update(1)

        # VAE no longer needed — free memory
        del vae_utils
        soft_empty_cache()

        if not dataset:
            raise ValueError("[LoRA Trainer] No clips could be loaded.")
        print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)

        # --- Prepare generator copy with LoRA ---
        generator = copy.deepcopy(model["generator"]).to(device, dtype)
        n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
                            target_suffixes=target_suffixes)
        if n_lora == 0:
            raise RuntimeError(
                f"[LoRA Trainer] No layers matched target={target_suffixes}. "
                "Check the 'target' field."
            )
        print(f"[LoRA Trainer] Wrapped {n_lora} layers (rank={rank}, alpha={alpha_val})", flush=True)

        # Only parameters whose name contains "lora_" are trained.
        for name, p in generator.named_parameters():
            p.requires_grad_("lora_" in name)

        generator.update_seq_lengths(
            latent_seq_len=seq_cfg.latent_seq_len,
            clip_seq_len=seq_cfg.clip_seq_len,
            sync_seq_len=seq_cfg.sync_seq_len,
        )

        # --- Optimizer + scheduler ---
        lora_params = [p for p in generator.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(lora_params, lr=lr, weight_decay=1e-2)

        def lr_lambda(s):
            # Linear warm-up to the base LR, constant afterwards.
            return s / max(1, warmup_steps) if s < warmup_steps else 1.0

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)

        # --- Resume ---
        start_step = 0
        if resume_path.strip():
            # SECURITY NOTE(review): weights_only=False unpickles arbitrary
            # objects, so a malicious checkpoint can execute code on load.
            # Only resume from checkpoints you created or trust. Consider
            # weights_only=True once confirmed that the stored optimizer and
            # scheduler state loads with it.
            ckpt = torch.load(resume_path.strip(), map_location="cpu", weights_only=False)
            if "step" not in ckpt:
                raise ValueError("[LoRA Trainer] Checkpoint has no step info.")
            start_step = ckpt["step"]
            if start_step >= steps:
                raise ValueError(
                    f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
                )
            load_lora(generator, ckpt["state_dict"])
            optimizer.load_state_dict(ckpt["optimizer"])
            scheduler.load_state_dict(ckpt["scheduler"])
            print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)

        # --- Training loop ---
        generator.train()
        optimizer.zero_grad()

        log_interval = 50
        remaining = steps - start_step
        pbar_train = comfy.utils.ProgressBar(remaining)
        loss_history = []
        running_loss = 0.0
        running_count = 0   # iterations accumulated into running_loss

        meta = {
            "variant": variant,
            "rank": rank,
            "alpha": alpha_val,
            "target": list(target_suffixes),
            "steps": steps,
        }

        print(f"\n[LoRA Trainer] Training {remaining} steps "
              f"(step {start_step + 1}-{steps})\n", flush=True)

        for step in range(start_step + 1, steps + 1):
            x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
            # copy=True: without it .to() returns the cached dataset tensor
            # itself when device and dtype already match (e.g. CPU float32).
            # NOTE(review): normalize()'s result is discarded below, so it
            # presumably works in place and would then overwrite the dataset
            # entry on every visit — confirm against selva_core.
            x1 = x1_cpu.to(device, dtype, copy=True)
            clip_f = clip_f_cpu.to(device, dtype)
            sync_f = sync_f_cpu.to(device, dtype)
            text_clip = text_clip_cpu.to(device, dtype)

            generator.normalize(x1)

            t = torch.rand(1, device=device, dtype=dtype)
            x0 = torch.randn_like(x1)
            xt = fm.get_conditional_flow(x0, x1, t)

            v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
            # Scale so the accumulated gradient is the mean over grad_accum iterations.
            loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
            loss.backward()
            running_loss += loss.item() * grad_accum   # undo the scaling for logging
            running_count += 1

            if step % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            progress_sent = False
            if step % log_interval == 0:
                # Divide by the iterations actually accumulated: after a
                # resume from a step that is not a multiple of log_interval
                # the first window is shorter than log_interval.
                avg = running_loss / running_count
                loss_history.append(avg)
                lr_now = scheduler.get_last_lr()[0]
                print(f"[LoRA Trainer] step {step:5d}/{steps} "
                      f"loss={avg:.4f} lr={lr_now:.2e}", flush=True)
                running_loss = 0.0
                running_count = 0

                # Live preview: send updated loss curve to ComfyUI frontend
                preview_img = _draw_loss_curve(loss_history, log_interval, start_step)
                pbar_train.update_absolute(
                    step - start_step, remaining, ("JPEG", preview_img, 85)
                )
                progress_sent = True

            if step % save_every == 0 or step == steps:
                ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
                torch.save({
                    "state_dict": get_lora_state_dict(generator),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "step": step,
                    "meta": meta,
                }, ckpt_path)
                print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)

                # Save a quick eval sample next to the checkpoint
                wav, sr = _eval_sample(generator, feature_utils_orig,
                                       dataset, seq_cfg, device, dtype)
                if wav is not None:
                    wav_path = output_dir / f"sample_step{step:05d}.wav"
                    if self._save_sample(wav_path, wav, sr):
                        print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)

            # update_absolute() above already moved the bar to this step;
            # advancing again would count the iteration twice.
            if not progress_sent:
                pbar_train.update(1)

        # Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
        final_path = output_dir / "adapter_final.pt"
        torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
        (output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
        print(f"\n[LoRA Trainer] Done. Adapter saved to {final_path}", flush=True)

        # --- Return patched model ---
        generator.eval()
        # Put the trained copy on the same device as the original generator.
        generator.to(next(model["generator"].parameters()).device)
        patched = {**model, "generator": generator}

        loss_curve = _pil_to_tensor(_draw_loss_curve(loss_history, log_interval, start_step))
        return (patched, str(final_path), loss_curve)