feat: PiSSA init, rsLoRA scaling, Spectral Surgery, and training fixes

LoRA quality improvements addressing intruder dimension problem:

1. PiSSA initialization (arXiv:2404.02948): init A,B from top-r SVD of
   pretrained weight. Starts on-manifold, eliminates intruder dimensions
   at init. Base weight stores residual W_res = W - B@A*scale.

2. rsLoRA scaling (arXiv:2312.03732): alpha/sqrt(rank) instead of
   alpha/rank. Prevents gradient collapse at high ranks (128+).

3. Post-training Spectral Surgery (arXiv:2603.03995): SVD of trained
   LoRA update, gradient-sensitivity reweighting to suppress remaining
   intruder dimensions. Runs automatically after training completes.

4. alpha default changed to 2*rank (was 1*rank). Produces fewer intruder
   dimensions per arXiv:2410.21228.

5. weight_decay reduced from 1e-2 to 0.0 (standard for LoRA, prevents
   erasing learned style weights).

6. random.choices replaced with random.sample when batch_size <= dataset
   size (eliminates duplicate samples per batch).

PiSSA checkpoints include base weights (residual). Loader/evaluator
updated to handle both standard and PiSSA checkpoint formats.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 21:54:36 +02:00
parent ecf828b007
commit 784fb2753f
4 changed files with 297 additions and 34 deletions
+12 -5
View File
@@ -58,19 +58,26 @@ class SelvaLoraLoader:
state_dict = checkpoint
meta = {}
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
init_mode = meta.get("init_mode", "standard")
use_rslora = meta.get("use_rslora", False)
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} strength={strength}",
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} "
f"init={init_mode} rslora={use_rslora} strength={strength}",
flush=True)
# Shallow-copy the model bundle so the original generator is not mutated
patched = {**model}
generator = copy.deepcopy(model["generator"])
n = apply_lora(generator, rank=rank, alpha=alpha, target_suffixes=tuple(target))
# For PiSSA, use standard init (the base weights will be overwritten
# by load_state_dict since the checkpoint includes linear.weight)
n = apply_lora(generator, rank=rank, alpha=alpha,
target_suffixes=tuple(target),
init_mode="standard", use_rslora=use_rslora)
if n == 0:
raise RuntimeError(
f"[SelVA LoRA] No layers matched target={target}. "