feat: PiSSA init, rsLoRA scaling, Spectral Surgery, and training fixes
LoRA quality improvements addressing the intruder-dimension problem:

1. PiSSA initialization (arXiv:2404.02948): initialize A and B from the top-r SVD of the pretrained weight. Training starts on-manifold, which eliminates intruder dimensions at init. The base weight stores the residual W_res = W - B@A*scale.
2. rsLoRA scaling (arXiv:2312.03732): scale by alpha/sqrt(rank) instead of alpha/rank. Prevents gradient collapse at high ranks (128+).
3. Post-training Spectral Surgery (arXiv:2603.03995): take the SVD of the trained LoRA update and reweight its singular values by gradient sensitivity to suppress the remaining intruder dimensions. Runs automatically after training completes.
4. The alpha default is changed to 2*rank (was 1*rank). This produces fewer intruder dimensions, according to arXiv:2410.21228.
5. weight_decay is reduced from 1e-2 to 0.0 (standard for LoRA; prevents erasing learned style weights).
6. random.choices is replaced with random.sample when batch_size <= dataset size (eliminates duplicate samples within a batch).

PiSSA checkpoints include the base weights (residual). The loader and evaluator are updated to handle both standard and PiSSA checkpoint formats.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+79
-12
@@ -21,7 +21,10 @@ import folder_paths
|
||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
|
||||
from selva_core.model.lora import (
|
||||
apply_lora, get_lora_state_dict, get_lora_and_base_state_dict, load_lora,
|
||||
spectral_surgery,
|
||||
)
|
||||
|
||||
|
||||
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
|
||||
@@ -486,8 +489,9 @@ class SelvaLoraTrainer:
|
||||
},
|
||||
"optional": {
|
||||
"alpha": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 256.0, "step": 0.5,
|
||||
"tooltip": "LoRA alpha. 0 = use rank value (scale = 1.0).",
|
||||
"default": 0.0, "min": 0.0, "max": 512.0, "step": 0.5,
|
||||
"tooltip": "LoRA alpha. 0 = use 2×rank (fewer intruder dimensions, "
|
||||
"arXiv:2410.21228). Set explicitly to override.",
|
||||
}),
|
||||
"target": ("STRING", {
|
||||
"default": "attn.qkv",
|
||||
@@ -525,13 +529,27 @@ class SelvaLoraTrainer:
|
||||
"lora_dropout": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.01,
|
||||
"tooltip": "Dropout applied to the LoRA path only (not the frozen base weights). "
|
||||
"0=disabled. 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610).",
|
||||
"0=disabled. 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610). "
|
||||
"Forced to 0 when using PiSSA init.",
|
||||
}),
|
||||
"lora_plus_ratio": ("FLOAT", {
|
||||
"default": 1.0, "min": 1.0, "max": 32.0, "step": 1.0,
|
||||
"tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. "
|
||||
"1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).",
|
||||
}),
|
||||
"init_mode": (["standard", "pissa"], {
|
||||
"default": "pissa",
|
||||
"tooltip": "LoRA initialization mode. "
|
||||
"standard: Kaiming-uniform A + zero B (classic LoRA). "
|
||||
"pissa: A and B from top-r SVD of pretrained weight — starts "
|
||||
"on-manifold, eliminates intruder dimensions (arXiv:2404.02948). "
|
||||
"Recommended for audio generation where off-manifold latents cause noise.",
|
||||
}),
|
||||
"use_rslora": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Rank-stabilized LoRA scaling: alpha/sqrt(rank) instead of alpha/rank. "
|
||||
"Prevents gradient collapse at high ranks (128+). Recommended (arXiv:2312.03732).",
|
||||
}),
|
||||
"lr_schedule": (["constant", "cosine"], {
|
||||
"default": "constant",
|
||||
"tooltip": "LR schedule after warmup. "
|
||||
@@ -562,7 +580,8 @@ class SelvaLoraTrainer:
|
||||
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
||||
grad_accum=1, save_every=500, resume_path="", seed=42,
|
||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant"):
|
||||
lora_dropout=0.0, lora_plus_ratio=1.0,
|
||||
init_mode="pissa", use_rslora=True, lr_schedule="constant"):
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
@@ -595,7 +614,7 @@ class SelvaLoraTrainer:
|
||||
output_dir = _out_p
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
alpha_val = float(alpha) if alpha > 0.0 else float(rank)
|
||||
alpha_val = float(alpha) if alpha > 0.0 else float(2 * rank)
|
||||
target_suffixes = tuple(target.strip().split())
|
||||
|
||||
dataset = _prepare_dataset(model, data_dir, device)
|
||||
@@ -613,6 +632,7 @@ class SelvaLoraTrainer:
|
||||
grad_accum, save_every, resume_path, seed,
|
||||
timestep_mode, logit_normal_sigma, curriculum_switch,
|
||||
lora_dropout, lora_plus_ratio, lr_schedule,
|
||||
init_mode, use_rslora,
|
||||
)
|
||||
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
|
||||
|
||||
@@ -624,19 +644,24 @@ class SelvaLoraTrainer:
|
||||
grad_accum, save_every, resume_path, seed,
|
||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant",
|
||||
init_mode="pissa", use_rslora=True,
|
||||
):
|
||||
# --- Prepare generator copy with LoRA ---
|
||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||
|
||||
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
|
||||
target_suffixes=target_suffixes, dropout=lora_dropout)
|
||||
target_suffixes=target_suffixes, dropout=lora_dropout,
|
||||
init_mode=init_mode, use_rslora=use_rslora)
|
||||
if n_lora == 0:
|
||||
raise RuntimeError(
|
||||
f"[LoRA Trainer] No layers matched target={target_suffixes}. "
|
||||
"Check the 'target' field."
|
||||
)
|
||||
scale_str = f"alpha/√rank={alpha_val/math.sqrt(rank):.2f}" if use_rslora \
|
||||
else f"alpha/rank={alpha_val/rank:.2f}"
|
||||
print(f"[LoRA Trainer] Wrapped {n_lora} layers "
|
||||
f"(rank={rank}, alpha={alpha_val}, dropout={lora_dropout})", flush=True)
|
||||
f"(rank={rank}, alpha={alpha_val}, {scale_str}, "
|
||||
f"init={init_mode}, dropout={lora_dropout})", flush=True)
|
||||
|
||||
for name, p in generator.named_parameters():
|
||||
p.requires_grad_("lora_" in name)
|
||||
@@ -655,7 +680,7 @@ class SelvaLoraTrainer:
|
||||
optimizer = torch.optim.AdamW([
|
||||
{"params": lora_A_params, "lr": lr},
|
||||
{"params": lora_B_params, "lr": lr * lora_plus_ratio},
|
||||
], weight_decay=1e-2)
|
||||
], weight_decay=0.0)
|
||||
if lora_plus_ratio != 1.0:
|
||||
print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True)
|
||||
|
||||
@@ -721,6 +746,8 @@ class SelvaLoraTrainer:
|
||||
"lora_dropout": lora_dropout,
|
||||
"lora_plus_ratio": lora_plus_ratio,
|
||||
"lr_schedule": lr_schedule,
|
||||
"init_mode": init_mode,
|
||||
"use_rslora": use_rslora,
|
||||
}
|
||||
|
||||
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
|
||||
@@ -735,7 +762,10 @@ class SelvaLoraTrainer:
|
||||
completed = False
|
||||
try:
|
||||
for step in range(start_step + 1, steps + 1):
|
||||
batch = random.choices(dataset, k=batch_size)
|
||||
if batch_size <= len(dataset):
|
||||
batch = random.sample(dataset, k=batch_size)
|
||||
else:
|
||||
batch = random.choices(dataset, k=batch_size)
|
||||
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||||
|
||||
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
||||
@@ -815,8 +845,11 @@ class SelvaLoraTrainer:
|
||||
|
||||
if step % save_every == 0 or step == steps:
|
||||
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
||||
# PiSSA checkpoints need base weights (residual W_res)
|
||||
sd = get_lora_and_base_state_dict(generator) if init_mode == "pissa" \
|
||||
else get_lora_state_dict(generator)
|
||||
torch.save({
|
||||
"state_dict": get_lora_state_dict(generator),
|
||||
"state_dict": sd,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"step": step,
|
||||
@@ -854,6 +887,38 @@ class SelvaLoraTrainer:
|
||||
|
||||
completed = True
|
||||
|
||||
# ── Post-training Spectral Surgery ────────────────────────────────
|
||||
# Reweight LoRA singular values using gradient sensitivity on the
|
||||
# training set. Suppresses intruder dimensions, amplifies useful ones.
|
||||
# (arXiv:2603.03995). Only run on normal completion.
|
||||
try:
|
||||
print("[LoRA Trainer] Running Spectral Surgery...", flush=True)
|
||||
fm_surgery = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
||||
|
||||
def _calibration_fn(model_cal, step_idx):
|
||||
sample = dataset[step_idx % len(dataset)]
|
||||
x1_cal, clip_cal, sync_cal, text_cal = sample
|
||||
x1_b = x1_cal.unsqueeze(0).to(device, dtype) if x1_cal.dim() == 2 \
|
||||
else x1_cal.to(device, dtype)
|
||||
x1_b = model_cal.normalize(x1_b.clone())
|
||||
clip_b = clip_cal.to(device, dtype)
|
||||
sync_b = sync_cal.to(device, dtype)
|
||||
text_b = text_cal.to(device, dtype)
|
||||
t = torch.rand(1, device=device, dtype=dtype)
|
||||
x0_b = torch.randn_like(x1_b)
|
||||
xt = fm_surgery.get_conditional_flow(x0_b, x1_b, t)
|
||||
v_pred = model_cal.forward(xt, clip_b, sync_b, text_b, t)
|
||||
cal_loss = fm_surgery.loss(v_pred, x0_b, x1_b).mean()
|
||||
cal_loss.backward()
|
||||
|
||||
n_cal = min(128, len(dataset) * 4)
|
||||
n_surgery = spectral_surgery(generator, _calibration_fn,
|
||||
n_calibration=n_cal)
|
||||
print(f"[LoRA Trainer] Spectral Surgery done: {n_surgery} layers processed.",
|
||||
flush=True)
|
||||
except Exception as e:
|
||||
print(f"[LoRA Trainer] Spectral Surgery failed (non-fatal): {e}", flush=True)
|
||||
|
||||
finally:
|
||||
# Save adapter and loss curves whether training completed or was cancelled.
|
||||
# Skip if we never completed a single step (nothing useful to save).
|
||||
@@ -872,7 +937,9 @@ class SelvaLoraTrainer:
|
||||
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
|
||||
label = f"Cancelled at step {last_step}"
|
||||
|
||||
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
|
||||
final_sd = get_lora_and_base_state_dict(generator) if init_mode == "pissa" \
|
||||
else get_lora_state_dict(generator)
|
||||
torch.save({"state_dict": final_sd, "meta": meta}, final_path)
|
||||
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||||
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user