Files
ComfyUI-SelVA/train_lora.py
T
Ethanfel eb63c1ead7 feat: add LoRA dropout, LoRA+ asymmetric LR, and curriculum timestep sampling
- LoRA dropout: applied to the LoRA path only (not frozen base weights),
  0.05–0.1 helps regularize on small datasets (arXiv:2404.09610)
- LoRA+: separate optimizer param groups for lora_A and lora_B with
  configurable LR ratio; ratio=16 enables LoRA+ (arXiv:2402.12354)
- Curriculum mode: logit_normal for first N% of steps then uniform,
  directly addresses early convergence + fine-detail degradation at
  boundaries (arXiv:2603.12517)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 12:43:18 +02:00

466 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
LoRA fine-tuning for SelVA / MMAudio generator.
Teaches the model new or partially-known sound classes from custom video+audio pairs.
Only the LoRA adapter weights are trained (~10 MB vs ~4.4 GB for the full model).
Data layout:
data/my_sound/
clip01.npz # visual features extracted by SelvaFeatureExtractor in ComfyUI
clip01.wav # paired clean audio (same filename stem, any format)
prompts.txt # optional: "clip01.npz: description" — overrides embedded prompt
If prompts.txt is absent, the prompt embedded in each .npz is used.
If the .npz has no embedded prompt, the directory name is used as fallback.
Usage:
python train_lora.py \\
--data_dir data/my_sound \\
--output_dir lora_output \\
--variant large_44k \\
--selva_dir /path/to/ComfyUI/models/selva \\
--rank 16 --steps 2000 --lr 1e-4
"""
import argparse
import os
import sys
import random
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import open_clip
from open_clip import create_model_from_pretrained
sys.path.insert(0, os.path.dirname(__file__))
from selva_core.model.networks_generator import get_my_mmaudio
from selva_core.model.utils.features_utils import FeaturesUtils, patch_clip
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.lora import apply_lora, get_lora_state_dict
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Maps each --variant choice to (generator checkpoint filename, sequence-config mode).
_VARIANTS = {
    "small_16k": ("generator_small_16k_sup_5.pth", "16k"),
    "small_44k": ("generator_small_44k_sup_5.pth", "44k"),
    "medium_44k": ("generator_medium_44k_sup_5.pth", "44k"),
    "large_44k": ("generator_large_44k_sup_5.pth", "44k"),
}

# Audio extensions probed when pairing audio with a .npz, in priority order.
# This is an ordered tuple rather than a set on purpose: set iteration order for
# strings changes with the interpreter's hash seed, so when several audio files
# share one stem a set would make the selected file vary from run to run.
_AUDIO_EXTS = (".wav", ".flac", ".aiff", ".aif", ".ogg", ".mp3")
# ---------------------------------------------------------------------------
# Data helpers
# ---------------------------------------------------------------------------
def load_prompts(data_dir: Path) -> dict:
    """Load filename → prompt overrides from ``prompts.txt`` in *data_dir*.

    Each useful line has the form ``<filename>: <prompt>``. Blank lines, lines
    starting with ``#`` and lines without a colon are ignored. Only the first
    colon separates the filename from the prompt, so prompts may contain colons.

    Args:
        data_dir: Directory that may contain a ``prompts.txt`` file.

    Returns:
        Mapping of filename (as written before the colon, whitespace-stripped)
        to prompt text. Empty when ``prompts.txt`` does not exist.
    """
    prompts_file = data_dir / "prompts.txt"
    if not prompts_file.exists():
        return {}

    mapping = {}
    # "utf-8-sig" drops a leading byte-order mark if one is present and is
    # otherwise identical to "utf-8". Without it, a BOM written by some editors
    # stays attached to the first line, so the first filename key (or a leading
    # "#" comment marker) would silently fail to match.
    for raw_line in prompts_file.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        fname, sep, prompt = line.partition(":")
        if sep:
            mapping[fname.strip()] = prompt.strip()
    return mapping
def find_audio_for_npz(npz_path: Path) -> Path | None:
    """Return the audio file that sits next to *npz_path* with the same stem.

    Every extension in ``_AUDIO_EXTS`` is tried, in that collection's iteration
    order, by swapping the ``.npz`` suffix for the audio suffix. The first
    candidate that exists on disk is returned; ``None`` if none exists.
    """
    candidates = (npz_path.with_suffix(suffix) for suffix in _AUDIO_EXTS)
    return next((path for path in candidates if path.exists()), None)
def load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
    """Read an audio file and return a mono float waveform of fixed length.

    The file is loaded with ``torchaudio.load``, multi-channel audio is averaged
    down to one channel, the signal is resampled to *target_sr* when its native
    rate differs, and the result is truncated or right-padded with zeros so it
    has exactly ``int(duration * target_sr)`` samples.

    Args:
        path: Audio file to read.
        target_sr: Sample rate of the returned waveform.
        duration: Length of the returned waveform in seconds.

    Returns:
        1-D float tensor of ``int(duration * target_sr)`` samples.
    """
    samples, native_sr = torchaudio.load(str(path))

    # Collapse multi-channel audio to a single channel by averaging.
    if samples.shape[0] > 1:
        samples = samples.mean(0, keepdim=True)
    samples = samples.squeeze(0).float()

    if native_sr != target_sr:
        # resample is given a leading channel dim, which is dropped again after.
        resampled = torchaudio.functional.resample(samples.unsqueeze(0), native_sr, target_sr)
        samples = resampled.squeeze(0)

    n_target = int(duration * target_sr)
    shortfall = n_target - samples.shape[0]
    if shortfall > 0:
        return F.pad(samples, (0, shortfall))
    return samples[:n_target]
def load_npz(path: Path) -> dict:
    """Load a feature bundle produced by SelvaFeatureExtractor.

    Args:
        path: ``.npz`` archive holding at least ``clip_features`` and
            ``sync_features`` arrays, and optionally ``prompt`` / ``variant``
            string entries.

    Returns:
        Dict with tensor entries ``clip_features`` and ``sync_features`` plus,
        when present in the archive, ``prompt`` and ``variant`` as ``str``.

    Raises:
        KeyError: If ``clip_features`` or ``sync_features`` is missing.
    """

    def _as_text(value) -> str:
        # A string saved through np.savez comes back as an array. Unwrap
        # single-element arrays (0-d or shape (1,)) so a value stored as ["text"]
        # is returned as "text" instead of the array repr "['text']".
        arr = np.asarray(value)
        return str(arr.item()) if arr.size == 1 else str(arr)

    # NpzFile keeps the archive's file handle open until it is closed; the
    # context manager releases it once the arrays have been read into memory,
    # instead of leaking one handle per clip.
    with np.load(str(path), allow_pickle=False) as data:
        bundle = {
            "clip_features": torch.from_numpy(data["clip_features"]),  # expected [1, N, 1024]
            "sync_features": torch.from_numpy(data["sync_features"]),  # expected [1, T, 768]
        }
        for key in ("prompt", "variant"):
            if key in data:
                bundle[key] = _as_text(data[key])
    return bundle
# ---------------------------------------------------------------------------
# Feature extraction (audio + text only — visual features come from .npz)
# ---------------------------------------------------------------------------
def encode_text_clip(clip_model, tokenizer, text: list[str], device) -> torch.Tensor:
    """Tokenize *text* and return the CLIP text embeddings.

    The prompts are tokenized, moved to *device* and passed to
    ``clip_model.encode_text(..., normalize=True)`` under
    ``torch.inference_mode()``, so the result carries no autograd history.
    """
    token_ids = tokenizer(text).to(device)
    with torch.inference_mode():
        embeddings = clip_model.encode_text(token_ids, normalize=True)
    return embeddings
def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> torch.Tensor:
    """Encode a waveform into the generator's latent space through the VAE.

    A batch dimension is added to *audio*, it is moved to *device*/*dtype* and
    encoded with ``feature_utils.encode_audio``. The mode of the returned
    distribution is cloned, its last two axes are swapped, and the result is
    returned on the CPU.

    The clone is kept on purpose: per the original note, ``encode_audio`` runs
    under inference mode and its output must be copied before it is used on an
    autograd path.
    """
    batched = audio.unsqueeze(0).to(device, dtype)  # [1, L]
    posterior = feature_utils.encode_audio(batched)
    latent = posterior.mode().clone()
    # VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
    return latent.transpose(1, 2).cpu()  # [1, seq_len, latent_dim]
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser describing every training option."""
    parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator")
    parser.add_argument("--data_dir", required=True, help="Directory with .npz + audio pairs and optional prompts.txt")
    parser.add_argument("--output_dir", default="lora_output")
    parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys()))
    parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)")
    parser.add_argument("--rank", type=int, default=16, help="LoRA rank")
    parser.add_argument("--alpha", type=float, default=None, help="LoRA alpha (default: rank)")
    parser.add_argument("--target", nargs="+", default=["attn.qkv"],
                        help="Module name suffixes to wrap with LoRA. Also try 'linear1'.")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--steps", type=int, default=2000)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=4, help="Clips per training step")
    parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--resume", default=None,
                        help="Path to a step checkpoint (.pt) to resume training from.")
    parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal", "curriculum"],
                        help="Timestep sampling. uniform=original MMAudio, logit_normal=concentrated near t=0.5, curriculum=logit_normal then uniform.")
    parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
                        help="Spread of logit-normal distribution.")
    parser.add_argument("--curriculum_switch", type=float, default=0.6,
                        help="Fraction of steps to use logit_normal before switching to uniform (curriculum mode only).")
    parser.add_argument("--lora_dropout", type=float, default=0.0,
                        help="Dropout on the LoRA path only. 0.05-0.1 helps on small datasets.")
    parser.add_argument("--lora_plus_ratio", type=float, default=1.0,
                        help="LoRA+ LR ratio: lr_B = lr * ratio. 1.0=standard, 16.0=LoRA+.")
    return parser


def _fit_seq_len(features: torch.Tensor, target_len: int) -> torch.Tensor:
    """Zero-pad or truncate a 3-D tensor along dim 1 so it has exactly target_len steps."""
    current = features.shape[1]
    if current < target_len:
        return F.pad(features, (0, 0, 0, target_len - current))
    if current > target_len:
        return features[:, :target_len, :]
    return features


def _build_meta(args) -> dict:
    """Collect the hyper-parameters stored alongside every saved adapter."""
    return {
        "variant": args.variant,
        "rank": args.rank,
        "alpha": args.alpha if args.alpha is not None else float(args.rank),
        "target": args.target,
        "steps": args.steps,
        "timestep_mode": args.timestep_mode,
        "logit_normal_sigma": args.logit_normal_sigma,
        "curriculum_switch": args.curriculum_switch,
        "lora_dropout": args.lora_dropout,
        "lora_plus_ratio": args.lora_plus_ratio,
    }


def main():
    """Command-line entry point: train LoRA adapters on a directory of clips.

    Steps, in order: parse arguments; load the CLIP text encoder, the VAE and
    the generator; wrap the targeted layers with LoRA and freeze everything
    else; pre-encode every (.npz, audio) pair into fixed-length tensors held in
    memory; then run the flow-matching training loop, writing a step checkpoint
    every ``--save_every`` steps and a final adapter plus ``meta.json`` at the
    end.

    Exits with status 1 on missing weights, no usable clips, no wrapped layers
    or an invalid resume checkpoint, and with status 0 when the resume
    checkpoint is already at or past ``--steps``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # These values are used as a modulus or a sample count inside the loop;
    # zero would only surface later as ZeroDivisionError / an empty stack().
    for flag in ("batch_size", "grad_accum", "save_every"):
        if getattr(args, flag) < 1:
            parser.error(f"--{flag} must be >= 1")

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
        print("[LoRA] bf16 not supported on this GPU — falling back to fp16")
        args.precision = "fp16"
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.precision]

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    selva_dir = Path(args.selva_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    gen_filename, mode = _VARIANTS[args.variant]
    seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
    duration = seq_cfg.duration
    sample_rate = seq_cfg.sampling_rate

    # --- Weight paths ---
    vae_weight = str(selva_dir / "ext" / ("v1-16.pth" if mode == "16k" else "v1-44.pth"))
    gen_weight = str(selva_dir / gen_filename)
    for path, label in [(vae_weight, "VAE"), (gen_weight, "generator")]:
        if not Path(path).exists():
            print(f"[LoRA] Missing weight: {path} ({label})")
            print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.")
            sys.exit(1)

    # --- Load CLIP text encoder (separate from FeaturesUtils to avoid loading Synchformer/T5) ---
    print("[LoRA] Loading CLIP text encoder...")
    clip_model = create_model_from_pretrained(
        'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False
    ).to(device, dtype).eval()
    clip_model = patch_clip(clip_model)
    tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')

    # --- Load VAE (FeaturesUtils with enable_conditions=False — no Synchformer/T5) ---
    print("[LoRA] Loading VAE encoder...")
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=vae_weight,
        enable_conditions=False,
        mode=mode,
        need_vae_encoder=True,
    ).to(device, dtype).eval()

    # --- Load generator ---
    print(f"[LoRA] Loading generator ({args.variant})...")
    net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval()
    # NOTE(security): weights_only=False unpickles arbitrary objects. Only load
    # generator weights obtained from a trusted source.
    net_generator.load_weights(
        torch.load(gen_weight, map_location="cpu", weights_only=False)
    )

    # --- Apply LoRA ---
    n_lora = apply_lora(
        net_generator,
        rank=args.rank,
        alpha=args.alpha,
        target_suffixes=tuple(args.target),
        dropout=args.lora_dropout,
    )
    print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target}, dropout={args.lora_dropout})")
    if n_lora == 0:
        print("[LoRA] ERROR: no layers were wrapped — check --target names.")
        sys.exit(1)

    # Freeze everything except LoRA params
    for name, p in net_generator.named_parameters():
        p.requires_grad_("lora_" in name)
    trainable = sum(p.numel() for p in net_generator.parameters() if p.requires_grad)
    total = sum(p.numel() for p in net_generator.parameters())
    print(f"[LoRA] Trainable: {trainable:,} / {total:,} params "
          f"({100 * trainable / total:.2f}%)")

    net_generator.update_seq_lengths(
        latent_seq_len=seq_cfg.latent_seq_len,
        clip_seq_len=seq_cfg.clip_seq_len,
        sync_seq_len=seq_cfg.sync_seq_len,
    )

    # --- Dataset ---
    npz_files = sorted(data_dir.glob("*.npz"))
    if not npz_files:
        print(f"[LoRA] No .npz files found in {data_dir}")
        sys.exit(1)

    prompt_map = load_prompts(data_dir)
    default_prompt = data_dir.name

    print(f"[LoRA] Pre-loading {len(npz_files)} clip(s)...")
    dataset = []
    for npz_path in npz_files:
        audio_path = find_audio_for_npz(npz_path)
        if audio_path is None:
            print(f" [LoRA] Warning: no audio file found for {npz_path.name} — skipping")
            continue

        bundle = load_npz(npz_path)
        # Prompt priority: prompts.txt override > embedded in .npz > directory name
        prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
        print(f" {npz_path.name} + {audio_path.name}: '{prompt}'")

        # Best effort by design: one unreadable clip is reported and skipped so
        # it does not abort pre-loading of the remaining clips.
        try:
            audio = load_audio(audio_path, sample_rate, duration)
            x1 = extract_audio_latent(audio, feature_utils, device, dtype)
            # STFT rounding can produce ±1 frame — pad or trim to exact seq length
            x1 = _fit_seq_len(x1, seq_cfg.latent_seq_len)

            text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()

            # Pad/trim clip and sync features to fixed seq lengths — shorter clips
            # have fewer frames and would cause stack() to fail during batching
            clip_f = _fit_seq_len(bundle["clip_features"], seq_cfg.clip_seq_len)  # [1, N_clip, 1024]
            sync_f = _fit_seq_len(bundle["sync_features"], seq_cfg.sync_seq_len)  # [1, N_sync, 768]

            dataset.append((x1, clip_f, sync_f, text_clip))
        except Exception as e:
            print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")

    if not dataset:
        print("[LoRA] No clips could be loaded.")
        sys.exit(1)
    print(f"[LoRA] {len(dataset)} clip(s) ready.")

    # --- Optimizer + LR scheduler ---
    # LoRA+: separate param groups for A and B with different LRs.
    # ratio=1.0 = standard LoRA. ratio=16 = LoRA+ (arXiv:2402.12354).
    lora_A_params = [p for n, p in net_generator.named_parameters() if "lora_A" in n and p.requires_grad]
    lora_B_params = [p for n, p in net_generator.named_parameters() if "lora_B" in n and p.requires_grad]
    # Built once here instead of concatenating the two lists on every step.
    clip_params = lora_A_params + lora_B_params
    optimizer = torch.optim.AdamW([
        {"params": lora_A_params, "lr": args.lr},
        {"params": lora_B_params, "lr": args.lr * args.lora_plus_ratio},
    ], weight_decay=1e-2)
    if args.lora_plus_ratio != 1.0:
        print(f"[LoRA] LoRA+: lr_A={args.lr:.2e} lr_B={args.lr * args.lora_plus_ratio:.2e}")

    def lr_lambda(step):
        # Linear warm-up to the base LR, then constant.
        if step < args.warmup_steps:
            return step / max(1, args.warmup_steps)
        return 1.0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)

    # --- Resume ---
    start_step = 0
    if args.resume:
        # NOTE(security): weights_only=False unpickles arbitrary objects. Only
        # resume from checkpoints written by this script.
        ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
        if "step" not in ckpt:
            print("[LoRA] ERROR: checkpoint has no step info — was it saved by this script?")
            sys.exit(1)
        start_step = ckpt["step"]
        if start_step >= args.steps:
            print(f"[LoRA] Checkpoint is already at step {start_step} >= --steps {args.steps}. Nothing to do.")
            sys.exit(0)
        net_generator.load_state_dict(ckpt["state_dict"], strict=False)
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        print(f"[LoRA] Resumed from {Path(args.resume).name} (step {start_step} -> {args.steps})")

    # --- Training loop ---
    net_generator.train()
    optimizer.zero_grad()
    meta = _build_meta(args)

    remaining = args.steps - start_step
    print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1}-{args.steps}), "
          f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
    print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")

    # The switch point is a fraction of the TOTAL step budget, so a resumed run
    # follows the same schedule as an uninterrupted one instead of restarting
    # the logit-normal phase over the remaining steps.
    curriculum_switch_step = int(args.steps * args.curriculum_switch)

    total_loss = 0.0
    window_steps = 0  # steps accumulated since the last loss report

    for step in range(start_step + 1, args.steps + 1):
        # Sampled with replacement, so batch_size may exceed the dataset size.
        batch = random.choices(dataset, k=args.batch_size)
        x1_list, clip_list, sync_list, text_list = zip(*batch)

        x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
        clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
        sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
        text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)

        # The original call discarded the return value, which is only correct if
        # normalize() works in place. Using the returned tensor when there is
        # one keeps the step correct for an out-of-place implementation too.
        normalized = net_generator.normalize(x1)
        if normalized is not None:
            x1 = normalized

        use_logit_normal = args.timestep_mode == "logit_normal" or (
            args.timestep_mode == "curriculum" and step <= curriculum_switch_step
        )
        if use_logit_normal:
            u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma
            t = torch.sigmoid(u)
        else:
            t = torch.rand(args.batch_size, device=device, dtype=dtype)
            # Announce the transition once, on the first uniform step.
            if args.timestep_mode == "curriculum" and step == curriculum_switch_step + 1:
                print(f"[LoRA] Curriculum switch: logit_normal → uniform at step {step}")

        x0 = torch.randn_like(x1)
        xt = fm.get_conditional_flow(x0, x1, t)

        v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t)
        loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum
        loss.backward()
        # Undo the grad_accum scaling so the logged value is the per-step loss.
        total_loss += loss.item() * args.grad_accum
        window_steps += 1

        if step % args.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(clip_params, max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % 50 == 0:
            # Divide by the steps actually accumulated: after a resume from a
            # step that is not a multiple of 50 the first window is shorter.
            avg = total_loss / window_steps
            lr_now = scheduler.get_last_lr()[0]
            print(f"[LoRA] step {step:5d}/{args.steps} loss={avg:.4f} lr={lr_now:.2e}")
            total_loss = 0.0
            window_steps = 0

        if step % args.save_every == 0 or step == args.steps:
            ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
            torch.save({
                "state_dict": get_lora_state_dict(net_generator),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "step": step,
                "meta": dict(meta),
            }, ckpt_path)
            print(f"[LoRA] Saved {ckpt_path}")

    # Save final adapter with embedded metadata
    # Increment filename if a previous final already exists (resume case)
    final = output_dir / "adapter_final.pt"
    if final.exists():
        i = 1
        while (output_dir / f"adapter_final_{i:03d}.pt").exists():
            i += 1
        final = output_dir / f"adapter_final_{i:03d}.pt"

    torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
    (output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
    print(f"\n[LoRA] Training complete. Adapter saved to {final}")
# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()