feat: PiSSA init, rsLoRA scaling, Spectral Surgery, and training fixes
LoRA quality improvements addressing intruder dimension problem: 1. PiSSA initialization (arXiv:2404.02948): init A,B from top-r SVD of pretrained weight. Starts on-manifold, eliminates intruder dimensions at init. Base weight stores residual W_res = W - B@A*scale. 2. rsLoRA scaling (arXiv:2312.03732): alpha/sqrt(rank) instead of alpha/rank. Prevents gradient collapse at high ranks (128+). 3. Post-training Spectral Surgery (arXiv:2603.03995): SVD of trained LoRA update, gradient-sensitivity reweighting to suppress remaining intruder dimensions. Runs automatically after training completes. 4. alpha default changed to 2*rank (was 1*rank). Produces fewer intruder dimensions per arXiv:2410.21228. 5. weight_decay reduced from 1e-2 to 0.0 (standard for LoRA, prevents erasing learned style weights). 6. random.choices replaced with random.sample when batch_size <= dataset size (eliminates duplicate samples per batch). PiSSA checkpoints include base weights (residual). Loader/evaluator updated to handle both standard and PiSSA checkpoint formats. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -312,14 +312,18 @@ class SelvaLoraEvaluator:
|
||||
state_dict = ckpt
|
||||
meta = {}
|
||||
|
||||
rank = int(meta.get("rank", 16))
|
||||
alpha = float(meta.get("alpha", float(rank)))
|
||||
target = list(meta.get("target", ["attn.qkv"]))
|
||||
dropout = float(meta.get("lora_dropout", 0.0))
|
||||
rank = int(meta.get("rank", 16))
|
||||
alpha = float(meta.get("alpha", float(rank)))
|
||||
target = list(meta.get("target", ["attn.qkv"]))
|
||||
dropout = float(meta.get("lora_dropout", 0.0))
|
||||
use_rslora = meta.get("use_rslora", False)
|
||||
record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
|
||||
|
||||
# Always use standard init for loading — PiSSA checkpoints
|
||||
# include linear.weight (residual) in state_dict
|
||||
n = apply_lora(generator, rank=rank, alpha=alpha,
|
||||
target_suffixes=tuple(target), dropout=dropout)
|
||||
target_suffixes=tuple(target), dropout=dropout,
|
||||
init_mode="standard", use_rslora=use_rslora)
|
||||
if n == 0:
|
||||
raise RuntimeError(
|
||||
f"apply_lora matched 0 layers (target={target})"
|
||||
|
||||
Reference in New Issue
Block a user