feat: PiSSA init, rsLoRA scaling, Spectral Surgery, and training fixes

LoRA quality improvements addressing intruder dimension problem:

1. PiSSA initialization (arXiv:2404.02948): init A,B from top-r SVD of
   pretrained weight. Starts on-manifold, eliminates intruder dimensions
   at init. Base weight stores residual W_res = W - B@A*scale.

2. rsLoRA scaling (arXiv:2312.03732): alpha/sqrt(rank) instead of
   alpha/rank. Prevents gradient collapse at high ranks (128+).

3. Post-training Spectral Surgery (arXiv:2603.03995): SVD of trained
   LoRA update, gradient-sensitivity reweighting to suppress remaining
   intruder dimensions. Runs automatically after training completes.

4. alpha default changed to 2*rank (was 1*rank). Produces fewer intruder
   dimensions per arXiv:2410.21228.

5. weight_decay reduced from 1e-2 to 0.0 (standard for LoRA, prevents
   erasing learned style weights).

6. random.choices replaced with random.sample when batch_size <= dataset
   size (eliminates duplicate samples per batch).

PiSSA checkpoints include base weights (residual). Loader/evaluator
updated to handle both standard and PiSSA checkpoint formats.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 21:54:36 +02:00
parent ecf828b007
commit 784fb2753f
4 changed files with 297 additions and 34 deletions
+9 -5
View File
@@ -312,14 +312,18 @@ class SelvaLoraEvaluator:
state_dict = ckpt
meta = {}
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
dropout = float(meta.get("lora_dropout", 0.0))
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
dropout = float(meta.get("lora_dropout", 0.0))
use_rslora = meta.get("use_rslora", False)
record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
# Always use standard init for loading — PiSSA checkpoints
# include linear.weight (residual) in state_dict
n = apply_lora(generator, rank=rank, alpha=alpha,
target_suffixes=tuple(target), dropout=dropout)
target_suffixes=tuple(target), dropout=dropout,
init_mode="standard", use_rslora=use_rslora)
if n == 0:
raise RuntimeError(
f"apply_lora matched 0 layers (target={target})"