feat: PiSSA init, rsLoRA scaling, Spectral Surgery, and training fixes
LoRA quality improvements addressing intruder dimension problem: 1. PiSSA initialization (arXiv:2404.02948): init A,B from top-r SVD of pretrained weight. Starts on-manifold, eliminates intruder dimensions at init. Base weight stores residual W_res = W - B@A*scale. 2. rsLoRA scaling (arXiv:2312.03732): alpha/sqrt(rank) instead of alpha/rank. Prevents gradient collapse at high ranks (128+). 3. Post-training Spectral Surgery (arXiv:2603.03995): SVD of trained LoRA update, gradient-sensitivity reweighting to suppress remaining intruder dimensions. Runs automatically after training completes. 4. alpha default changed to 2*rank (was 1*rank). Produces fewer intruder dimensions per arXiv:2410.21228. 5. weight_decay reduced from 1e-2 to 0.0 (standard for LoRA, prevents erasing learned style weights). 6. random.choices replaced with random.sample when batch_size <= dataset size (eliminates duplicate samples per batch). PiSSA checkpoints include base weights (residual). Loader/evaluator updated to handle both standard and PiSSA checkpoint formats. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -316,10 +316,14 @@ class SelvaLoraEvaluator:
|
|||||||
alpha = float(meta.get("alpha", float(rank)))
|
alpha = float(meta.get("alpha", float(rank)))
|
||||||
target = list(meta.get("target", ["attn.qkv"]))
|
target = list(meta.get("target", ["attn.qkv"]))
|
||||||
dropout = float(meta.get("lora_dropout", 0.0))
|
dropout = float(meta.get("lora_dropout", 0.0))
|
||||||
|
use_rslora = meta.get("use_rslora", False)
|
||||||
record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
|
record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
|
||||||
|
|
||||||
|
# Always use standard init for loading — PiSSA checkpoints
|
||||||
|
# include linear.weight (residual) in state_dict
|
||||||
n = apply_lora(generator, rank=rank, alpha=alpha,
|
n = apply_lora(generator, rank=rank, alpha=alpha,
|
||||||
target_suffixes=tuple(target), dropout=dropout)
|
target_suffixes=tuple(target), dropout=dropout,
|
||||||
|
init_mode="standard", use_rslora=use_rslora)
|
||||||
if n == 0:
|
if n == 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"apply_lora matched 0 layers (target={target})"
|
f"apply_lora matched 0 layers (target={target})"
|
||||||
|
|||||||
@@ -61,16 +61,23 @@ class SelvaLoraLoader:
|
|||||||
rank = int(meta.get("rank", 16))
|
rank = int(meta.get("rank", 16))
|
||||||
alpha = float(meta.get("alpha", float(rank)))
|
alpha = float(meta.get("alpha", float(rank)))
|
||||||
target = list(meta.get("target", ["attn.qkv"]))
|
target = list(meta.get("target", ["attn.qkv"]))
|
||||||
|
init_mode = meta.get("init_mode", "standard")
|
||||||
|
use_rslora = meta.get("use_rslora", False)
|
||||||
|
|
||||||
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
|
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
|
||||||
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} strength={strength}",
|
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} "
|
||||||
|
f"init={init_mode} rslora={use_rslora} strength={strength}",
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
# Shallow-copy the model bundle so the original generator is not mutated
|
# Shallow-copy the model bundle so the original generator is not mutated
|
||||||
patched = {**model}
|
patched = {**model}
|
||||||
generator = copy.deepcopy(model["generator"])
|
generator = copy.deepcopy(model["generator"])
|
||||||
|
|
||||||
n = apply_lora(generator, rank=rank, alpha=alpha, target_suffixes=tuple(target))
|
# For PiSSA, use standard init (the base weights will be overwritten
|
||||||
|
# by load_state_dict since the checkpoint includes linear.weight)
|
||||||
|
n = apply_lora(generator, rank=rank, alpha=alpha,
|
||||||
|
target_suffixes=tuple(target),
|
||||||
|
init_mode="standard", use_rslora=use_rslora)
|
||||||
if n == 0:
|
if n == 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"[SelVA LoRA] No layers matched target={target}. "
|
f"[SelVA LoRA] No layers matched target={target}. "
|
||||||
|
|||||||
+78
-11
@@ -21,7 +21,10 @@ import folder_paths
|
|||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||||
from selva_core.model.utils.features_utils import FeaturesUtils
|
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||||
from selva_core.model.flow_matching import FlowMatching
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
|
from selva_core.model.lora import (
|
||||||
|
apply_lora, get_lora_state_dict, get_lora_and_base_state_dict, load_lora,
|
||||||
|
spectral_surgery,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
|
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
|
||||||
@@ -486,8 +489,9 @@ class SelvaLoraTrainer:
|
|||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"alpha": ("FLOAT", {
|
"alpha": ("FLOAT", {
|
||||||
"default": 0.0, "min": 0.0, "max": 256.0, "step": 0.5,
|
"default": 0.0, "min": 0.0, "max": 512.0, "step": 0.5,
|
||||||
"tooltip": "LoRA alpha. 0 = use rank value (scale = 1.0).",
|
"tooltip": "LoRA alpha. 0 = use 2×rank (fewer intruder dimensions, "
|
||||||
|
"arXiv:2410.21228). Set explicitly to override.",
|
||||||
}),
|
}),
|
||||||
"target": ("STRING", {
|
"target": ("STRING", {
|
||||||
"default": "attn.qkv",
|
"default": "attn.qkv",
|
||||||
@@ -525,13 +529,27 @@ class SelvaLoraTrainer:
|
|||||||
"lora_dropout": ("FLOAT", {
|
"lora_dropout": ("FLOAT", {
|
||||||
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.01,
|
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.01,
|
||||||
"tooltip": "Dropout applied to the LoRA path only (not the frozen base weights). "
|
"tooltip": "Dropout applied to the LoRA path only (not the frozen base weights). "
|
||||||
"0=disabled. 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610).",
|
"0=disabled. 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610). "
|
||||||
|
"Forced to 0 when using PiSSA init.",
|
||||||
}),
|
}),
|
||||||
"lora_plus_ratio": ("FLOAT", {
|
"lora_plus_ratio": ("FLOAT", {
|
||||||
"default": 1.0, "min": 1.0, "max": 32.0, "step": 1.0,
|
"default": 1.0, "min": 1.0, "max": 32.0, "step": 1.0,
|
||||||
"tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. "
|
"tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. "
|
||||||
"1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).",
|
"1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).",
|
||||||
}),
|
}),
|
||||||
|
"init_mode": (["standard", "pissa"], {
|
||||||
|
"default": "pissa",
|
||||||
|
"tooltip": "LoRA initialization mode. "
|
||||||
|
"standard: Kaiming-uniform A + zero B (classic LoRA). "
|
||||||
|
"pissa: A and B from top-r SVD of pretrained weight — starts "
|
||||||
|
"on-manifold, eliminates intruder dimensions (arXiv:2404.02948). "
|
||||||
|
"Recommended for audio generation where off-manifold latents cause noise.",
|
||||||
|
}),
|
||||||
|
"use_rslora": ("BOOLEAN", {
|
||||||
|
"default": True,
|
||||||
|
"tooltip": "Rank-stabilized LoRA scaling: alpha/sqrt(rank) instead of alpha/rank. "
|
||||||
|
"Prevents gradient collapse at high ranks (128+). Recommended (arXiv:2312.03732).",
|
||||||
|
}),
|
||||||
"lr_schedule": (["constant", "cosine"], {
|
"lr_schedule": (["constant", "cosine"], {
|
||||||
"default": "constant",
|
"default": "constant",
|
||||||
"tooltip": "LR schedule after warmup. "
|
"tooltip": "LR schedule after warmup. "
|
||||||
@@ -562,7 +580,8 @@ class SelvaLoraTrainer:
|
|||||||
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
||||||
grad_accum=1, save_every=500, resume_path="", seed=42,
|
grad_accum=1, save_every=500, resume_path="", seed=42,
|
||||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||||
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant"):
|
lora_dropout=0.0, lora_plus_ratio=1.0,
|
||||||
|
init_mode="pissa", use_rslora=True, lr_schedule="constant"):
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@@ -595,7 +614,7 @@ class SelvaLoraTrainer:
|
|||||||
output_dir = _out_p
|
output_dir = _out_p
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
alpha_val = float(alpha) if alpha > 0.0 else float(rank)
|
alpha_val = float(alpha) if alpha > 0.0 else float(2 * rank)
|
||||||
target_suffixes = tuple(target.strip().split())
|
target_suffixes = tuple(target.strip().split())
|
||||||
|
|
||||||
dataset = _prepare_dataset(model, data_dir, device)
|
dataset = _prepare_dataset(model, data_dir, device)
|
||||||
@@ -613,6 +632,7 @@ class SelvaLoraTrainer:
|
|||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
timestep_mode, logit_normal_sigma, curriculum_switch,
|
timestep_mode, logit_normal_sigma, curriculum_switch,
|
||||||
lora_dropout, lora_plus_ratio, lr_schedule,
|
lora_dropout, lora_plus_ratio, lr_schedule,
|
||||||
|
init_mode, use_rslora,
|
||||||
)
|
)
|
||||||
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
|
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
|
||||||
|
|
||||||
@@ -624,19 +644,24 @@ class SelvaLoraTrainer:
|
|||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||||
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant",
|
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant",
|
||||||
|
init_mode="pissa", use_rslora=True,
|
||||||
):
|
):
|
||||||
# --- Prepare generator copy with LoRA ---
|
# --- Prepare generator copy with LoRA ---
|
||||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||||
|
|
||||||
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
|
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
|
||||||
target_suffixes=target_suffixes, dropout=lora_dropout)
|
target_suffixes=target_suffixes, dropout=lora_dropout,
|
||||||
|
init_mode=init_mode, use_rslora=use_rslora)
|
||||||
if n_lora == 0:
|
if n_lora == 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"[LoRA Trainer] No layers matched target={target_suffixes}. "
|
f"[LoRA Trainer] No layers matched target={target_suffixes}. "
|
||||||
"Check the 'target' field."
|
"Check the 'target' field."
|
||||||
)
|
)
|
||||||
|
scale_str = f"alpha/√rank={alpha_val/math.sqrt(rank):.2f}" if use_rslora \
|
||||||
|
else f"alpha/rank={alpha_val/rank:.2f}"
|
||||||
print(f"[LoRA Trainer] Wrapped {n_lora} layers "
|
print(f"[LoRA Trainer] Wrapped {n_lora} layers "
|
||||||
f"(rank={rank}, alpha={alpha_val}, dropout={lora_dropout})", flush=True)
|
f"(rank={rank}, alpha={alpha_val}, {scale_str}, "
|
||||||
|
f"init={init_mode}, dropout={lora_dropout})", flush=True)
|
||||||
|
|
||||||
for name, p in generator.named_parameters():
|
for name, p in generator.named_parameters():
|
||||||
p.requires_grad_("lora_" in name)
|
p.requires_grad_("lora_" in name)
|
||||||
@@ -655,7 +680,7 @@ class SelvaLoraTrainer:
|
|||||||
optimizer = torch.optim.AdamW([
|
optimizer = torch.optim.AdamW([
|
||||||
{"params": lora_A_params, "lr": lr},
|
{"params": lora_A_params, "lr": lr},
|
||||||
{"params": lora_B_params, "lr": lr * lora_plus_ratio},
|
{"params": lora_B_params, "lr": lr * lora_plus_ratio},
|
||||||
], weight_decay=1e-2)
|
], weight_decay=0.0)
|
||||||
if lora_plus_ratio != 1.0:
|
if lora_plus_ratio != 1.0:
|
||||||
print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True)
|
print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True)
|
||||||
|
|
||||||
@@ -721,6 +746,8 @@ class SelvaLoraTrainer:
|
|||||||
"lora_dropout": lora_dropout,
|
"lora_dropout": lora_dropout,
|
||||||
"lora_plus_ratio": lora_plus_ratio,
|
"lora_plus_ratio": lora_plus_ratio,
|
||||||
"lr_schedule": lr_schedule,
|
"lr_schedule": lr_schedule,
|
||||||
|
"init_mode": init_mode,
|
||||||
|
"use_rslora": use_rslora,
|
||||||
}
|
}
|
||||||
|
|
||||||
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
|
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
|
||||||
@@ -735,6 +762,9 @@ class SelvaLoraTrainer:
|
|||||||
completed = False
|
completed = False
|
||||||
try:
|
try:
|
||||||
for step in range(start_step + 1, steps + 1):
|
for step in range(start_step + 1, steps + 1):
|
||||||
|
if batch_size <= len(dataset):
|
||||||
|
batch = random.sample(dataset, k=batch_size)
|
||||||
|
else:
|
||||||
batch = random.choices(dataset, k=batch_size)
|
batch = random.choices(dataset, k=batch_size)
|
||||||
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||||||
|
|
||||||
@@ -815,8 +845,11 @@ class SelvaLoraTrainer:
|
|||||||
|
|
||||||
if step % save_every == 0 or step == steps:
|
if step % save_every == 0 or step == steps:
|
||||||
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
||||||
|
# PiSSA checkpoints need base weights (residual W_res)
|
||||||
|
sd = get_lora_and_base_state_dict(generator) if init_mode == "pissa" \
|
||||||
|
else get_lora_state_dict(generator)
|
||||||
torch.save({
|
torch.save({
|
||||||
"state_dict": get_lora_state_dict(generator),
|
"state_dict": sd,
|
||||||
"optimizer": optimizer.state_dict(),
|
"optimizer": optimizer.state_dict(),
|
||||||
"scheduler": scheduler.state_dict(),
|
"scheduler": scheduler.state_dict(),
|
||||||
"step": step,
|
"step": step,
|
||||||
@@ -854,6 +887,38 @@ class SelvaLoraTrainer:
|
|||||||
|
|
||||||
completed = True
|
completed = True
|
||||||
|
|
||||||
|
# ── Post-training Spectral Surgery ────────────────────────────────
|
||||||
|
# Reweight LoRA singular values using gradient sensitivity on the
|
||||||
|
# training set. Suppresses intruder dimensions, amplifies useful ones.
|
||||||
|
# (arXiv:2603.03995). Only run on normal completion.
|
||||||
|
try:
|
||||||
|
print("[LoRA Trainer] Running Spectral Surgery...", flush=True)
|
||||||
|
fm_surgery = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
||||||
|
|
||||||
|
def _calibration_fn(model_cal, step_idx):
|
||||||
|
sample = dataset[step_idx % len(dataset)]
|
||||||
|
x1_cal, clip_cal, sync_cal, text_cal = sample
|
||||||
|
x1_b = x1_cal.unsqueeze(0).to(device, dtype) if x1_cal.dim() == 2 \
|
||||||
|
else x1_cal.to(device, dtype)
|
||||||
|
x1_b = model_cal.normalize(x1_b.clone())
|
||||||
|
clip_b = clip_cal.to(device, dtype)
|
||||||
|
sync_b = sync_cal.to(device, dtype)
|
||||||
|
text_b = text_cal.to(device, dtype)
|
||||||
|
t = torch.rand(1, device=device, dtype=dtype)
|
||||||
|
x0_b = torch.randn_like(x1_b)
|
||||||
|
xt = fm_surgery.get_conditional_flow(x0_b, x1_b, t)
|
||||||
|
v_pred = model_cal.forward(xt, clip_b, sync_b, text_b, t)
|
||||||
|
cal_loss = fm_surgery.loss(v_pred, x0_b, x1_b).mean()
|
||||||
|
cal_loss.backward()
|
||||||
|
|
||||||
|
n_cal = min(128, len(dataset) * 4)
|
||||||
|
n_surgery = spectral_surgery(generator, _calibration_fn,
|
||||||
|
n_calibration=n_cal)
|
||||||
|
print(f"[LoRA Trainer] Spectral Surgery done: {n_surgery} layers processed.",
|
||||||
|
flush=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[LoRA Trainer] Spectral Surgery failed (non-fatal): {e}", flush=True)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Save adapter and loss curves whether training completed or was cancelled.
|
# Save adapter and loss curves whether training completed or was cancelled.
|
||||||
# Skip if we never completed a single step (nothing useful to save).
|
# Skip if we never completed a single step (nothing useful to save).
|
||||||
@@ -872,7 +937,9 @@ class SelvaLoraTrainer:
|
|||||||
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
|
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
|
||||||
label = f"Cancelled at step {last_step}"
|
label = f"Cancelled at step {last_step}"
|
||||||
|
|
||||||
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
|
final_sd = get_lora_and_base_state_dict(generator) if init_mode == "pissa" \
|
||||||
|
else get_lora_state_dict(generator)
|
||||||
|
torch.save({"state_dict": final_sd, "meta": meta}, final_path)
|
||||||
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||||||
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
|
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
|
||||||
|
|
||||||
|
|||||||
+193
-8
@@ -1,6 +1,17 @@
|
|||||||
"""
|
"""
|
||||||
LoRA (Low-Rank Adaptation) for SelVA / MMAudio generator.
|
LoRA (Low-Rank Adaptation) for SelVA / MMAudio generator.
|
||||||
|
|
||||||
|
Supports two initialization modes:
|
||||||
|
- **standard**: Kaiming-uniform A, zero B (classic LoRA).
|
||||||
|
- **pissa**: A and B from the top-r SVD of the pretrained weight.
|
||||||
|
Starts on-manifold, eliminates intruder dimensions at init
|
||||||
|
(arXiv:2404.02948, NeurIPS 2024 Spotlight).
|
||||||
|
|
||||||
|
Supports two scaling modes:
|
||||||
|
- **standard**: alpha / rank
|
||||||
|
- **rslora**: alpha / sqrt(rank) — rank-stabilized scaling that prevents
|
||||||
|
gradient collapse at high ranks (arXiv:2312.03732).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
|
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
|
||||||
|
|
||||||
@@ -25,14 +36,16 @@ import torch.nn as nn
|
|||||||
class LoRALinear(nn.Module):
|
class LoRALinear(nn.Module):
|
||||||
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
|
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
|
||||||
|
|
||||||
Output: base(x) + (dropout(x) @ A.T @ B.T) * (alpha / rank)
|
Output: base(x) + (dropout(x) @ A.T @ B.T) * scale
|
||||||
|
|
||||||
A is initialised with Kaiming uniform; B is initialised to zero so the
|
Standard init: A is Kaiming uniform, B is zero → adapter starts at zero.
|
||||||
adapter contribution starts at zero and does not disturb pretrained behaviour.
|
PiSSA init: A and B from top-r SVD of pretrained weight → adapter starts
|
||||||
Dropout is applied only to the LoRA path, not the base linear.
|
at the principal components, base weight stores the residual.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, linear: nn.Linear, rank: int, alpha: float, dropout: float = 0.0):
|
def __init__(self, linear: nn.Linear, rank: int, alpha: float,
|
||||||
|
dropout: float = 0.0, init_mode: str = "standard",
|
||||||
|
use_rslora: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
in_f = linear.in_features
|
in_f = linear.in_features
|
||||||
out_f = linear.out_features
|
out_f = linear.out_features
|
||||||
@@ -44,11 +57,35 @@ class LoRALinear(nn.Module):
|
|||||||
|
|
||||||
ref_dtype = linear.weight.dtype
|
ref_dtype = linear.weight.dtype
|
||||||
ref_device = linear.weight.device
|
ref_device = linear.weight.device
|
||||||
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
|
|
||||||
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
|
if use_rslora:
|
||||||
|
self.scale = alpha / math.sqrt(rank)
|
||||||
|
else:
|
||||||
self.scale = alpha / rank
|
self.scale = alpha / rank
|
||||||
|
|
||||||
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
|
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
|
||||||
|
|
||||||
|
if init_mode == "pissa":
|
||||||
|
# PiSSA: init from top-r SVD of pretrained weight.
|
||||||
|
# SVD in float32 for numerical stability, then cast back.
|
||||||
|
W = linear.weight.data.float() # [out_f, in_f]
|
||||||
|
U, S, Vt = torch.linalg.svd(W, full_matrices=False)
|
||||||
|
|
||||||
|
sqrt_S = S[:rank].sqrt()
|
||||||
|
# A: [rank, in_f], B: [out_f, rank]
|
||||||
|
A_init = sqrt_S.unsqueeze(1) * Vt[:rank, :]
|
||||||
|
B_init = U[:, :rank] * sqrt_S.unsqueeze(0)
|
||||||
|
|
||||||
|
# Residual: W_res = W - B_init @ A_init * scale
|
||||||
|
# so that base(x) + LoRA(x) = W_res@x + (B@A)*scale@x = W@x at init
|
||||||
|
linear.weight.data = (W - B_init @ A_init * self.scale).to(ref_dtype)
|
||||||
|
|
||||||
|
self.lora_A = nn.Parameter(A_init.to(dtype=ref_dtype, device=ref_device))
|
||||||
|
self.lora_B = nn.Parameter(B_init.to(dtype=ref_dtype, device=ref_device))
|
||||||
|
else:
|
||||||
|
# Standard LoRA: Kaiming A, zero B → starts at identity
|
||||||
|
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
|
||||||
|
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
|
||||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -67,6 +104,8 @@ def apply_lora(
|
|||||||
alpha: float = None,
|
alpha: float = None,
|
||||||
target_suffixes: tuple = ("attn.qkv",),
|
target_suffixes: tuple = ("attn.qkv",),
|
||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
|
init_mode: str = "standard",
|
||||||
|
use_rslora: bool = False,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Replace matching nn.Linear layers with LoRALinear in-place.
|
"""Replace matching nn.Linear layers with LoRALinear in-place.
|
||||||
|
|
||||||
@@ -80,6 +119,9 @@ def apply_lora(
|
|||||||
Add "linear1" to also wrap post-attention output projections.
|
Add "linear1" to also wrap post-attention output projections.
|
||||||
dropout: Dropout probability on the LoRA path (not the base linear).
|
dropout: Dropout probability on the LoRA path (not the base linear).
|
||||||
0.05–0.1 helps regularize on small datasets.
|
0.05–0.1 helps regularize on small datasets.
|
||||||
|
Must be 0 when using PiSSA (principal components shouldn't be dropped).
|
||||||
|
init_mode: "standard" (Kaiming/zero) or "pissa" (SVD-based).
|
||||||
|
use_rslora: If True, scale by alpha/sqrt(rank) instead of alpha/rank.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of linear layers wrapped.
|
Number of linear layers wrapped.
|
||||||
@@ -87,6 +129,11 @@ def apply_lora(
|
|||||||
if alpha is None:
|
if alpha is None:
|
||||||
alpha = float(rank)
|
alpha = float(rank)
|
||||||
|
|
||||||
|
if init_mode == "pissa" and dropout > 0.0:
|
||||||
|
print("[LoRA] Warning: dropout forced to 0 for PiSSA init "
|
||||||
|
"(principal components should not be dropped).")
|
||||||
|
dropout = 0.0
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
for name, module in list(model.named_modules()):
|
for name, module in list(model.named_modules()):
|
||||||
if not any(name.endswith(s) for s in target_suffixes):
|
if not any(name.endswith(s) for s in target_suffixes):
|
||||||
@@ -98,7 +145,10 @@ def apply_lora(
|
|||||||
parent = model
|
parent = model
|
||||||
for part in parts[:-1]:
|
for part in parts[:-1]:
|
||||||
parent = getattr(parent, part)
|
parent = getattr(parent, part)
|
||||||
setattr(parent, parts[-1], LoRALinear(module, rank, alpha, dropout=dropout))
|
setattr(parent, parts[-1], LoRALinear(
|
||||||
|
module, rank, alpha, dropout=dropout,
|
||||||
|
init_mode=init_mode, use_rslora=use_rslora,
|
||||||
|
))
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
return count
|
return count
|
||||||
@@ -109,6 +159,141 @@ def get_lora_state_dict(model: nn.Module) -> dict:
|
|||||||
return {k: v for k, v in model.state_dict().items() if "lora_" in k}
|
return {k: v for k, v in model.state_dict().items() if "lora_" in k}
|
||||||
|
|
||||||
|
|
||||||
|
def get_lora_and_base_state_dict(model: nn.Module) -> dict:
    """Collect LoRA tensors together with the wrapped base-layer tensors.

    PiSSA checkpoints need this instead of a LoRA-only state dict: with PiSSA
    the wrapped linear layer no longer holds the pretrained weight but the
    residual (W - top_r(W)*scale), so the base weight has to travel with the
    adapter.

    Args:
        model: Module tree to scan for ``LoRALinear`` layers.

    Returns:
        Dict mapping ``"<module name>.lora_A"``, ``"<module name>.lora_B"``,
        ``"<module name>.linear.weight"`` and, when the wrapped layer has a
        bias, ``"<module name>.linear.bias"`` to the corresponding ``.data``
        tensors. The tensors are not copied.
    """
    collected = {}
    wrapped_layers = (
        (path, layer)
        for path, layer in model.named_modules()
        if isinstance(layer, LoRALinear)
    )
    for path, layer in wrapped_layers:
        base = layer.linear
        # Order matters only for dict iteration: adapter first, then base.
        entries = [
            ("lora_A", layer.lora_A),
            ("lora_B", layer.lora_B),
            ("linear.weight", base.weight),
        ]
        if base.bias is not None:
            entries.append(("linear.bias", base.bias))
        for suffix, param in entries:
            collected[path + "." + suffix] = param.data
    return collected
|
||||||
|
|
||||||
|
|
||||||
|
def spectral_surgery(
    model: nn.Module,
    calibration_fn,
    n_calibration: int = 128,
    policy: str = "smooth_abs",
):
    """Post-training Spectral Surgery: reweight LoRA singular values to suppress
    intruder dimensions and amplify useful components (arXiv:2603.03995).

    For every ``LoRALinear`` layer the trained update ``ΔW = B @ A * scale`` is
    decomposed with an SVD, a per-component sensitivity
    ``g_k = u_k^T (dL/dΔW) v_k`` is averaged over ``n_calibration`` calibration
    passes, and the singular values are rescaled by a mask derived from ``g``
    (renormalised so the sum of singular values is unchanged). The result is
    factored back into ``lora_A`` / ``lora_B``.

    Args:
        model: Model with LoRA applied.
        calibration_fn: Callable that takes (model, step_idx) and runs one
            forward+backward pass on a calibration sample. Must call
            loss.backward().
        n_calibration: Number of calibration samples to average gradients over.
            Must be at least 1.
        policy: Reweighting policy: "smooth_abs" (recommended), "hard" (binary).
            Any value other than "hard" is treated as "smooth_abs".

    Returns:
        Number of layers processed (0 if the model has no LoRALinear layers).

    Raises:
        ValueError: If ``n_calibration`` is smaller than 1 and there is at
            least one LoRA layer to process.

    Side effects:
        Modifies LoRA A and B in-place, puts the model in eval mode, clears
        gradients, and leaves ``requires_grad`` disabled on the LoRA
        parameters.
    """
    model.eval()
    lora_layers = [(name, mod) for name, mod in model.named_modules()
                   if isinstance(mod, LoRALinear)]

    if not lora_layers:
        return 0
    if n_calibration < 1:
        raise ValueError(f"n_calibration must be >= 1, got {n_calibration}")

    # A and B are not modified during calibration, so the SVD of ΔW is
    # computed once per layer instead of once per layer per calibration step.
    # Slices are cloned so the full-size U / Vt buffers can be freed.
    bases = {}
    for name, mod in lora_layers:
        A = mod.lora_A.data.float()  # [rank, in_f]
        B = mod.lora_B.data.float()  # [out_f, rank]
        U, S, Vt = torch.linalg.svd(B @ A * mod.scale, full_matrices=False)
        r = A.shape[0]
        bases[name] = (U[:, :r].clone(), S[:r].clone(), Vt[:r, :].clone())

    # Accumulated per-layer sensitivity: g_k = u_k^T * (dL/dΔW) * v_k
    sensitivities = {name: torch.zeros_like(bases[name][1])
                     for name, _ in lora_layers}

    for step in range(n_calibration):
        model.zero_grad()
        # Enable grad temporarily on LoRA params
        for _, mod in lora_layers:
            mod.lora_A.requires_grad_(True)
            mod.lora_B.requires_grad_(True)

        calibration_fn(model, step)

        for name, mod in lora_layers:
            grad_A = mod.lora_A.grad
            grad_B = mod.lora_B.grad
            if grad_A is None or grad_B is None:
                # Layer did not take part in this pass: contributes zero.
                continue
            A = mod.lora_A.data.float()  # [rank, in_f]
            B = mod.lora_B.data.float()  # [out_f, rank]
            # With G = dL/dΔW: grad_B = scale * G @ A^T and grad_A = scale * B^T @ G,
            # so grad_B @ A + B @ grad_A is an [out_f, in_f] proxy for G.
            # (B.T @ grad_A would be [rank, out_f] @ [rank, in_f], which is a
            # shape error whenever out_f != rank.)
            grad_dW = grad_B.float() @ A + B @ grad_A.float()
            U_r, _, Vt_r = bases[name]
            # Per-component: g_k = u_k^T @ grad_dW @ v_k
            sensitivities[name] += torch.einsum("ik,ij,kj->k", U_r, grad_dW, Vt_r)

    # Disable grad again and drop the calibration gradients
    for _, mod in lora_layers:
        mod.lora_A.requires_grad_(False)
        mod.lora_B.requires_grad_(False)
    model.zero_grad()

    # Apply reweighting per layer
    count = 0
    for name, mod in lora_layers:
        g = sensitivities[name] / n_calibration
        U_r, S_r, Vt_r = bases[name]
        r = mod.lora_A.shape[0]

        if policy == "hard":
            # Keep components with positive sensitivity, zero out negative
            gate = (g > 0).float()
        else:
            # smooth_abs: normalize g to [-1, 1], then apply a steep sigmoid
            g_norm = g / (g.abs().max() + 1e-8)
            gate = torch.sigmoid(5.0 * g_norm)

        # Count on the raw gate; after the rescaling below, values no longer
        # compare meaningfully against 0.5.
        kept = int((gate > 0.5).sum().item())

        # L1 norm preservation: scale mask so total nuclear norm is preserved
        mask = gate * (S_r.sum() / (gate * S_r).sum().clamp(min=1e-8))
        S_new = mask * S_r

        # ΔW' = U_r @ diag(S_new) @ Vt_r is already in SVD form, so it can be
        # split into B' @ A' * scale directly, without a second SVD.
        sqrt_s = (S_new / abs(mod.scale)).sqrt()
        A_new = sqrt_s.unsqueeze(1) * Vt_r
        B_new = U_r * sqrt_s.unsqueeze(0)
        if mod.scale < 0:
            # Keep B' @ A' * scale == ΔW' for a negative scale.
            B_new = -B_new

        ref_dtype = mod.lora_A.dtype
        mod.lora_A.data = A_new.to(ref_dtype)
        mod.lora_B.data = B_new.to(ref_dtype)
        count += 1

        print(f"[Spectral Surgery] {name}: kept {kept}/{r} components, "
              f"sensitivity range [{g.min().item():.3f}, {g.max().item():.3f}]", flush=True)

    return count
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model: nn.Module, state_dict: dict) -> None:
|
def load_lora(model: nn.Module, state_dict: dict) -> None:
|
||||||
"""Load LoRA weights into a model that has already had apply_lora() called.
|
"""Load LoRA weights into a model that has already had apply_lora() called.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user