feat: PiSSA init, rsLoRA scaling, Spectral Surgery, and training fixes
LoRA quality improvements addressing intruder dimension problem:

1. PiSSA initialization (arXiv:2404.02948): init A,B from top-r SVD of
   pretrained weight. Starts on-manifold, eliminates intruder dimensions
   at init. Base weight stores residual W_res = W - B@A*scale.
2. rsLoRA scaling (arXiv:2312.03732): alpha/sqrt(rank) instead of
   alpha/rank. Prevents gradient collapse at high ranks (128+).
3. Post-training Spectral Surgery (arXiv:2603.03995): SVD of trained LoRA
   update, gradient-sensitivity reweighting to suppress remaining intruder
   dimensions. Runs automatically after training completes.
4. alpha default changed to 2*rank (was 1*rank). Produces fewer intruder
   dimensions per arXiv:2410.21228.
5. weight_decay reduced from 1e-2 to 0.0 (standard for LoRA, prevents
   erasing learned style weights).
6. random.choices replaced with random.sample when batch_size <= dataset
   size (eliminates duplicate samples per batch).

PiSSA checkpoints include base weights (residual). Loader/evaluator
updated to handle both standard and PiSSA checkpoint formats.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
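
For reference, items 1 and 2 can be sketched in NumPy as below. This is an illustration under stated assumptions, not the code in this commit: the helper names `lora_scale` and `pissa_init` are hypothetical, and splitting the 1/scale factor evenly between A and B is only one possible convention. It shows that the residual base weight plus the scaled low-rank product reproduces the pretrained weight at init.

```python
import math

import numpy as np


def lora_scale(alpha: float, rank: int, use_rslora: bool) -> float:
    """alpha/sqrt(rank) for rsLoRA, alpha/rank for standard LoRA."""
    return alpha / math.sqrt(rank) if use_rslora else alpha / rank


def pissa_init(weight: np.ndarray, rank: int, scale: float):
    """Hypothetical helper: return (A, B, W_res) such that
    W_res + scale * B @ A equals `weight` up to rounding."""
    u, s, vt = np.linalg.svd(weight, full_matrices=False)
    # Assumption: divide the top-r singular values by `scale` and split the
    # square root evenly between the two factors.
    sqrt_s = np.sqrt(s[:rank] / scale)
    B = u[:, :rank] * sqrt_s            # shape (out_features, rank)
    A = sqrt_s[:, None] * vt[:rank, :]  # shape (rank, in_features)
    w_res = weight - scale * (B @ A)    # W_res = W - B@A*scale
    return A, B, w_res


rng = np.random.default_rng(0)
W = rng.standard_normal((8, 6))  # stand-in for a pretrained weight
rank = 2
scale = lora_scale(alpha=2 * rank, rank=rank, use_rslora=True)
A, B, W_res = pissa_init(W, rank, scale)
```

With this construction the adapted layer computes the same output as the pretrained layer before any training step, and the residual holds only the singular directions beyond the top r.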
@@ -58,19 +58,26 @@ class SelvaLoraLoader:
 state_dict = checkpoint
 meta = {}
 
-rank   = int(meta.get("rank", 16))
-alpha  = float(meta.get("alpha", float(rank)))
-target = list(meta.get("target", ["attn.qkv"]))
+rank       = int(meta.get("rank", 16))
+alpha      = float(meta.get("alpha", float(rank)))
+target     = list(meta.get("target", ["attn.qkv"]))
+init_mode  = meta.get("init_mode", "standard")
+use_rslora = meta.get("use_rslora", False)
 
 print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
-print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} strength={strength}",
+print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} "
+      f"init={init_mode} rslora={use_rslora} strength={strength}",
       flush=True)
 
 # Shallow-copy the model bundle so the original generator is not mutated
 patched = {**model}
 generator = copy.deepcopy(model["generator"])
 
-n = apply_lora(generator, rank=rank, alpha=alpha, target_suffixes=tuple(target))
+# For PiSSA, use standard init (the base weights will be overwritten
+# by load_state_dict since the checkpoint includes linear.weight)
+n = apply_lora(generator, rank=rank, alpha=alpha,
+               target_suffixes=tuple(target),
+               init_mode="standard", use_rslora=use_rslora)
 if n == 0:
     raise RuntimeError(
         f"[SelVA LoRA] No layers matched target={target}. "
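
The metadata defaults in this hunk are what let the loader read checkpoints saved before this commit. A minimal sketch of that fallback logic follows; the helper `read_adapter_meta` and the `"pissa"` value for `init_mode` are assumptions made for illustration, and only the keys and defaults are taken from the diff.

```python
def read_adapter_meta(meta: dict) -> dict:
    """Hypothetical helper mirroring the defaults used in the loader hunk."""
    rank = int(meta.get("rank", 16))
    return {
        "rank": rank,
        # Loader-side fallback stays at alpha == rank when the key is absent.
        "alpha": float(meta.get("alpha", float(rank))),
        "target": list(meta.get("target", ["attn.qkv"])),
        # Checkpoints without these keys are treated as standard LoRA.
        "init_mode": meta.get("init_mode", "standard"),
        "use_rslora": bool(meta.get("use_rslora", False)),
    }


# A checkpoint written before this commit carries no init_mode/use_rslora.
legacy = read_adapter_meta({"rank": 8})

# A newer checkpoint; the "pissa" label is an assumed value.
pissa = read_adapter_meta(
    {"rank": 64, "alpha": 128, "init_mode": "pissa", "use_rslora": True}
)
```

Note that the loader's fallback for a missing `alpha` is still `rank`, as the diff shows, even though the training default described in the commit message moved to 2*rank.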