From 784fb2753f2c5740e09fe4bfe9cf4955f01bb782 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 21:54:36 +0200 Subject: [PATCH] feat: PiSSA init, rsLoRA scaling, Spectral Surgery, and training fixes LoRA quality improvements addressing intruder dimension problem: 1. PiSSA initialization (arXiv:2404.02948): init A,B from top-r SVD of pretrained weight. Starts on-manifold, eliminates intruder dimensions at init. Base weight stores residual W_res = W - B@A*scale. 2. rsLoRA scaling (arXiv:2312.03732): alpha/sqrt(rank) instead of alpha/rank. Prevents gradient collapse at high ranks (128+). 3. Post-training Spectral Surgery (arXiv:2603.03995): SVD of trained LoRA update, gradient-sensitivity reweighting to suppress remaining intruder dimensions. Runs automatically after training completes. 4. alpha default changed to 2*rank (was 1*rank). Produces fewer intruder dimensions per arXiv:2410.21228. 5. weight_decay reduced from 1e-2 to 0.0 (standard for LoRA, prevents erasing learned style weights). 6. random.choices replaced with random.sample when batch_size <= dataset size (eliminates duplicate samples per batch). PiSSA checkpoints include base weights (residual). Loader/evaluator updated to handle both standard and PiSSA checkpoint formats. Co-Authored-By: Claude Opus 4.6 --- nodes/selva_lora_evaluator.py | 14 ++- nodes/selva_lora_loader.py | 17 ++- nodes/selva_lora_trainer.py | 91 +++++++++++++-- selva_core/model/lora.py | 209 ++++++++++++++++++++++++++++++++-- 4 files changed, 297 insertions(+), 34 deletions(-) diff --git a/nodes/selva_lora_evaluator.py b/nodes/selva_lora_evaluator.py index 4360570..01e7fc9 100644 --- a/nodes/selva_lora_evaluator.py +++ b/nodes/selva_lora_evaluator.py @@ -312,14 +312,18 @@ class SelvaLoraEvaluator: state_dict = ckpt meta = {} - rank = int(meta.get("rank", 16)) - alpha = float(meta.get("alpha", float(rank))) - target = list(meta.get("target", ["attn.qkv"])) - dropout = float(meta.get("lora_dropout", 0.0)) + rank = int(meta.get("rank", 16)) + alpha = float(meta.get("alpha", float(rank))) + target = list(meta.get("target", ["attn.qkv"])) + dropout = float(meta.get("lora_dropout", 0.0)) + use_rslora = meta.get("use_rslora", False) record["meta"] = {"rank": rank, "alpha": alpha, "target": target} + # Always use standard init for loading — PiSSA checkpoints + # include linear.weight (residual) in state_dict n = apply_lora(generator, rank=rank, alpha=alpha, - target_suffixes=tuple(target), dropout=dropout) + target_suffixes=tuple(target), dropout=dropout, + init_mode="standard", use_rslora=use_rslora) if n == 0: raise RuntimeError( f"apply_lora matched 0 layers (target={target})" diff --git a/nodes/selva_lora_loader.py b/nodes/selva_lora_loader.py index fa1951a..ad64d50 100644 --- a/nodes/selva_lora_loader.py +++ b/nodes/selva_lora_loader.py @@ -58,19 +58,26 @@ class SelvaLoraLoader: state_dict = checkpoint meta = {} - rank = int(meta.get("rank", 16)) - alpha = float(meta.get("alpha", float(rank))) - target = list(meta.get("target", ["attn.qkv"])) + rank = int(meta.get("rank", 16)) + alpha = float(meta.get("alpha", float(rank))) + target = list(meta.get("target", ["attn.qkv"])) + init_mode = meta.get("init_mode", "standard") + use_rslora = meta.get("use_rslora", False) print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True) - print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} strength={strength}", + print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} " + f"init={init_mode} rslora={use_rslora} strength={strength}", flush=True) # Shallow-copy the model bundle so the original generator is not mutated patched = {**model} generator = copy.deepcopy(model["generator"]) - n = apply_lora(generator, rank=rank, alpha=alpha, target_suffixes=tuple(target)) + # For PiSSA, use standard init (the base weights will be overwritten + # by load_state_dict since the checkpoint includes linear.weight) + n = apply_lora(generator, rank=rank, alpha=alpha, + target_suffixes=tuple(target), + init_mode="standard", use_rslora=use_rslora) if n == 0: raise RuntimeError( f"[SelVA LoRA] No layers matched target={target}. " diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 2273632..413e8c6 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -21,7 +21,10 @@ import folder_paths from .utils import SELVA_CATEGORY, get_device, soft_empty_cache from selva_core.model.utils.features_utils import FeaturesUtils from selva_core.model.flow_matching import FlowMatching -from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora +from selva_core.model.lora import ( + apply_lora, get_lora_state_dict, get_lora_and_base_state_dict, load_lora, + spectral_surgery, +) _AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"} @@ -486,8 +489,9 @@ class SelvaLoraTrainer: }, "optional": { "alpha": ("FLOAT", { - "default": 0.0, "min": 0.0, "max": 256.0, "step": 0.5, - "tooltip": "LoRA alpha. 0 = use rank value (scale = 1.0).", + "default": 0.0, "min": 0.0, "max": 512.0, "step": 0.5, + "tooltip": "LoRA alpha. 0 = use 2×rank (fewer intruder dimensions, " + "arXiv:2410.21228). Set explicitly to override.", }), "target": ("STRING", { "default": "attn.qkv", @@ -525,13 +529,27 @@ class SelvaLoraTrainer: "lora_dropout": ("FLOAT", { "default": 0.0, "min": 0.0, "max": 0.3, "step": 0.01, "tooltip": "Dropout applied to the LoRA path only (not the frozen base weights). " - "0=disabled. 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610).", + "0=disabled. 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610). " + "Forced to 0 when using PiSSA init.", }), "lora_plus_ratio": ("FLOAT", { "default": 1.0, "min": 1.0, "max": 32.0, "step": 1.0, "tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. " "1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).", }), + "init_mode": (["standard", "pissa"], { + "default": "pissa", + "tooltip": "LoRA initialization mode. " + "standard: Kaiming-uniform A + zero B (classic LoRA). " + "pissa: A and B from top-r SVD of pretrained weight — starts " + "on-manifold, eliminates intruder dimensions (arXiv:2404.02948). " + "Recommended for audio generation where off-manifold latents cause noise.", + }), + "use_rslora": ("BOOLEAN", { + "default": True, + "tooltip": "Rank-stabilized LoRA scaling: alpha/sqrt(rank) instead of alpha/rank. " + "Prevents gradient collapse at high ranks (128+). Recommended (arXiv:2312.03732).", + }), "lr_schedule": (["constant", "cosine"], { "default": "constant", "tooltip": "LR schedule after warmup. " @@ -562,7 +580,8 @@ class SelvaLoraTrainer: alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100, grad_accum=1, save_every=500, resume_path="", seed=42, timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6, - lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant"): + lora_dropout=0.0, lora_plus_ratio=1.0, + init_mode="pissa", use_rslora=True, lr_schedule="constant"): torch.manual_seed(seed) random.seed(seed) @@ -595,7 +614,7 @@ class SelvaLoraTrainer: output_dir = _out_p output_dir.mkdir(parents=True, exist_ok=True) - alpha_val = float(alpha) if alpha > 0.0 else float(rank) + alpha_val = float(alpha) if alpha > 0.0 else float(2 * rank) target_suffixes = tuple(target.strip().split()) dataset = _prepare_dataset(model, data_dir, device) @@ -613,6 +632,7 @@ class SelvaLoraTrainer: grad_accum, save_every, resume_path, seed, timestep_mode, logit_normal_sigma, curriculum_switch, lora_dropout, lora_plus_ratio, lr_schedule, + init_mode, use_rslora, ) return (r["patched_model"], r["adapter_path"], r["loss_curve"]) @@ -624,19 +644,24 @@ class SelvaLoraTrainer: grad_accum, save_every, resume_path, seed, timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6, lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant", + init_mode="pissa", use_rslora=True, ): # --- Prepare generator copy with LoRA --- generator = copy.deepcopy(model["generator"]).to(device, dtype) n_lora = apply_lora(generator, rank=rank, alpha=alpha_val, - target_suffixes=target_suffixes, dropout=lora_dropout) + target_suffixes=target_suffixes, dropout=lora_dropout, + init_mode=init_mode, use_rslora=use_rslora) if n_lora == 0: raise RuntimeError( f"[LoRA Trainer] No layers matched target={target_suffixes}. " "Check the 'target' field." ) + scale_str = f"alpha/√rank={alpha_val/math.sqrt(rank):.2f}" if use_rslora \ + else f"alpha/rank={alpha_val/rank:.2f}" print(f"[LoRA Trainer] Wrapped {n_lora} layers " - f"(rank={rank}, alpha={alpha_val}, dropout={lora_dropout})", flush=True) + f"(rank={rank}, alpha={alpha_val}, {scale_str}, " + f"init={init_mode}, dropout={lora_dropout})", flush=True) for name, p in generator.named_parameters(): p.requires_grad_("lora_" in name) @@ -655,7 +680,7 @@ class SelvaLoraTrainer: optimizer = torch.optim.AdamW([ {"params": lora_A_params, "lr": lr}, {"params": lora_B_params, "lr": lr * lora_plus_ratio}, - ], weight_decay=1e-2) + ], weight_decay=0.0) if lora_plus_ratio != 1.0: print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True) @@ -721,6 +746,8 @@ class SelvaLoraTrainer: "lora_dropout": lora_dropout, "lora_plus_ratio": lora_plus_ratio, "lr_schedule": lr_schedule, + "init_mode": init_mode, + "use_rslora": use_rslora, } # For curriculum mode: compute the step at which we switch from logit_normal to uniform @@ -735,7 +762,10 @@ class SelvaLoraTrainer: completed = False try: for step in range(start_step + 1, steps + 1): - batch = random.choices(dataset, k=batch_size) + if batch_size <= len(dataset): + batch = random.sample(dataset, k=batch_size) + else: + batch = random.choices(dataset, k=batch_size) x1_list, clip_list, sync_list, text_list = zip(*batch) x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype) @@ -815,8 +845,11 @@ class SelvaLoraTrainer: if step % save_every == 0 or step == steps: ckpt_path = output_dir / f"adapter_step{step:05d}.pt" + # PiSSA checkpoints need base weights (residual W_res) + sd = get_lora_and_base_state_dict(generator) if init_mode == "pissa" \ + else get_lora_state_dict(generator) torch.save({ - "state_dict": get_lora_state_dict(generator), + "state_dict": sd, "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), "step": step, @@ -854,6 +887,38 @@ class SelvaLoraTrainer: completed = True + # ── Post-training Spectral Surgery ──────────────────────────────── + # Reweight LoRA singular values using gradient sensitivity on the + # training set. Suppresses intruder dimensions, amplifies useful ones. + # (arXiv:2603.03995). Only run on normal completion. + try: + print("[LoRA Trainer] Running Spectral Surgery...", flush=True) + fm_surgery = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25) + + def _calibration_fn(model_cal, step_idx): + sample = dataset[step_idx % len(dataset)] + x1_cal, clip_cal, sync_cal, text_cal = sample + x1_b = x1_cal.unsqueeze(0).to(device, dtype) if x1_cal.dim() == 2 \ + else x1_cal.to(device, dtype) + x1_b = model_cal.normalize(x1_b.clone()) + clip_b = clip_cal.to(device, dtype) + sync_b = sync_cal.to(device, dtype) + text_b = text_cal.to(device, dtype) + t = torch.rand(1, device=device, dtype=dtype) + x0_b = torch.randn_like(x1_b) + xt = fm_surgery.get_conditional_flow(x0_b, x1_b, t) + v_pred = model_cal.forward(xt, clip_b, sync_b, text_b, t) + cal_loss = fm_surgery.loss(v_pred, x0_b, x1_b).mean() + cal_loss.backward() + + n_cal = min(128, len(dataset) * 4) + n_surgery = spectral_surgery(generator, _calibration_fn, + n_calibration=n_cal) + print(f"[LoRA Trainer] Spectral Surgery done: {n_surgery} layers processed.", + flush=True) + except Exception as e: + print(f"[LoRA Trainer] Spectral Surgery failed (non-fatal): {e}", flush=True) + finally: # Save adapter and loss curves whether training completed or was cancelled. # Skip if we never completed a single step (nothing useful to save). @@ -872,7 +937,9 @@ class SelvaLoraTrainer: final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt" label = f"Cancelled at step {last_step}" - torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path) + final_sd = get_lora_and_base_state_dict(generator) if init_mode == "pissa" \ + else get_lora_state_dict(generator) + torch.save({"state_dict": final_sd, "meta": meta}, final_path) (output_dir / "meta.json").write_text(json.dumps(meta, indent=2)) print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True) diff --git a/selva_core/model/lora.py b/selva_core/model/lora.py index 86dd729..922ab0c 100644 --- a/selva_core/model/lora.py +++ b/selva_core/model/lora.py @@ -1,6 +1,17 @@ """ LoRA (Low-Rank Adaptation) for SelVA / MMAudio generator. +Supports two initialization modes: + - **standard**: Kaiming-uniform A, zero B (classic LoRA). + - **pissa**: A and B from the top-r SVD of the pretrained weight. + Starts on-manifold, eliminates intruder dimensions at init + (arXiv:2404.02948, NeurIPS 2024 Spotlight). + +Supports two scaling modes: + - **standard**: alpha / rank + - **rslora**: alpha / sqrt(rank) — rank-stabilized scaling that prevents + gradient collapse at high ranks (arXiv:2312.03732). + Usage: from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora @@ -25,14 +36,16 @@ import torch.nn as nn class LoRALinear(nn.Module): """nn.Linear with a frozen base weight and trainable low-rank A/B matrices. - Output: base(x) + (dropout(x) @ A.T @ B.T) * (alpha / rank) + Output: base(x) + (dropout(x) @ A.T @ B.T) * scale - A is initialised with Kaiming uniform; B is initialised to zero so the - adapter contribution starts at zero and does not disturb pretrained behaviour. - Dropout is applied only to the LoRA path, not the base linear. + Standard init: A is Kaiming uniform, B is zero → adapter starts at zero. + PiSSA init: A and B from top-r SVD of pretrained weight → adapter starts + at the principal components, base weight stores the residual. """ - def __init__(self, linear: nn.Linear, rank: int, alpha: float, dropout: float = 0.0): + def __init__(self, linear: nn.Linear, rank: int, alpha: float, + dropout: float = 0.0, init_mode: str = "standard", + use_rslora: bool = False): super().__init__() in_f = linear.in_features out_f = linear.out_features @@ -42,14 +55,38 @@ class LoRALinear(nn.Module): if linear.bias is not None: linear.bias.requires_grad_(False) - ref_dtype = linear.weight.dtype - ref_device = linear.weight.device - self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device)) - self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device)) - self.scale = alpha / rank + ref_dtype = linear.weight.dtype + ref_device = linear.weight.device + + if use_rslora: + self.scale = alpha / math.sqrt(rank) + else: + self.scale = alpha / rank + self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + if init_mode == "pissa": + # PiSSA: init from top-r SVD of pretrained weight. + # SVD in float32 for numerical stability, then cast back. + W = linear.weight.data.float() # [out_f, in_f] + U, S, Vt = torch.linalg.svd(W, full_matrices=False) + + sqrt_S = S[:rank].sqrt() + # A: [rank, in_f], B: [out_f, rank] + A_init = sqrt_S.unsqueeze(1) * Vt[:rank, :] + B_init = U[:, :rank] * sqrt_S.unsqueeze(0) + + # Residual: W_res = W - B_init @ A_init * scale + # so that base(x) + LoRA(x) = W_res@x + (B@A)*scale@x = W@x at init + linear.weight.data = (W - B_init @ A_init * self.scale).to(ref_dtype) + + self.lora_A = nn.Parameter(A_init.to(dtype=ref_dtype, device=ref_device)) + self.lora_B = nn.Parameter(B_init.to(dtype=ref_dtype, device=ref_device)) + else: + # Standard LoRA: Kaiming A, zero B → starts at identity + self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device)) + self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device)) + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) + (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scale @@ -67,6 +104,8 @@ def apply_lora( alpha: float = None, target_suffixes: tuple = ("attn.qkv",), dropout: float = 0.0, + init_mode: str = "standard", + use_rslora: bool = False, ) -> int: """Replace matching nn.Linear layers with LoRALinear in-place. @@ -80,6 +119,9 @@ def apply_lora( Add "linear1" to also wrap post-attention output projections. dropout: Dropout probability on the LoRA path (not the base linear). 0.05–0.1 helps regularize on small datasets. + Must be 0 when using PiSSA (principal components shouldn't be dropped). + init_mode: "standard" (Kaiming/zero) or "pissa" (SVD-based). + use_rslora: If True, scale by alpha/sqrt(rank) instead of alpha/rank. Returns: Number of linear layers wrapped. @@ -87,6 +129,11 @@ def apply_lora( if alpha is None: alpha = float(rank) + if init_mode == "pissa" and dropout > 0.0: + print("[LoRA] Warning: dropout forced to 0 for PiSSA init " + "(principal components should not be dropped).") + dropout = 0.0 + count = 0 for name, module in list(model.named_modules()): if not any(name.endswith(s) for s in target_suffixes): @@ -98,7 +145,10 @@ def apply_lora( parent = model for part in parts[:-1]: parent = getattr(parent, part) - setattr(parent, parts[-1], LoRALinear(module, rank, alpha, dropout=dropout)) + setattr(parent, parts[-1], LoRALinear( + module, rank, alpha, dropout=dropout, + init_mode=init_mode, use_rslora=use_rslora, + )) count += 1 return count @@ -109,6 +159,141 @@ def get_lora_state_dict(model: nn.Module) -> dict: return {k: v for k, v in model.state_dict().items() if "lora_" in k} +def get_lora_and_base_state_dict(model: nn.Module) -> dict: + """Return state dict with LoRA params AND base linear weights. + + Needed for PiSSA checkpoints where the base weight stores the residual + (W - top_r(W)*scale), not the original pretrained weight. + """ + result = {} + for name, module in model.named_modules(): + if isinstance(module, LoRALinear): + prefix = name + "." + result[prefix + "lora_A"] = module.lora_A.data + result[prefix + "lora_B"] = module.lora_B.data + result[prefix + "linear.weight"] = module.linear.weight.data + if module.linear.bias is not None: + result[prefix + "linear.bias"] = module.linear.bias.data + return result + + +def spectral_surgery( + model: nn.Module, + calibration_fn, + n_calibration: int = 128, + policy: str = "smooth_abs", +): + """Post-training Spectral Surgery: reweight LoRA singular values to suppress + intruder dimensions and amplify useful components (arXiv:2603.03995). + + Args: + model: Model with LoRA applied. + calibration_fn: Callable that takes (model, step_idx) and runs one forward+backward + pass on a calibration sample. Must call loss.backward(). + n_calibration: Number of calibration samples to average gradients over. + policy: Reweighting policy: "smooth_abs" (recommended), "hard" (binary). + + Modifies LoRA A and B in-place. Returns number of layers processed. + """ + model.eval() + lora_layers = [(name, mod) for name, mod in model.named_modules() + if isinstance(mod, LoRALinear)] + + if not lora_layers: + return 0 + + # Accumulate per-layer gradient sensitivity: g_k = u_k^T * (dL/dΔW) * v_k + sensitivities = {} + for name, mod in lora_layers: + sensitivities[name] = None + + for step in range(n_calibration): + model.zero_grad() + # Enable grad temporarily on LoRA params + for _, mod in lora_layers: + mod.lora_A.requires_grad_(True) + mod.lora_B.requires_grad_(True) + + calibration_fn(model, step) + + for name, mod in lora_layers: + A = mod.lora_A.data.float() # [rank, in_f] + B = mod.lora_B.data.float() # [out_f, rank] + # ΔW = B @ A * scale → gradient dL/dΔW ≈ (dL/dB @ A + B^T @ dL/dA) / 2 + # Per-component sensitivity: project onto SVD directions + delta_W = (B @ A * mod.scale).detach() + U, S, Vt = torch.linalg.svd(delta_W, full_matrices=False) + r = A.shape[0] + U_r, S_r, Vt_r = U[:, :r], S[:r], Vt[:r, :] + + # Compute sensitivity from LoRA gradients + if mod.lora_A.grad is not None and mod.lora_B.grad is not None: + grad_A = mod.lora_A.grad.float() # [rank, in_f] + grad_B = mod.lora_B.grad.float() # [out_f, rank] + # dL/d(ΔW) ≈ grad_B @ A + B^T @ grad_A (chain rule through B@A) + grad_dW = grad_B @ A + B.T @ grad_A # approximate + # Per-component: g_k = u_k^T @ grad_dW @ v_k + g = torch.einsum("ik,ij,jk->k", U_r, grad_dW, Vt_r.T) # [r] + else: + g = torch.zeros(r, device=A.device) + + if sensitivities[name] is None: + sensitivities[name] = g + else: + sensitivities[name] += g + + # Disable grad again + for _, mod in lora_layers: + mod.lora_A.requires_grad_(False) + mod.lora_B.requires_grad_(False) + + # Apply reweighting per layer + count = 0 + for name, mod in lora_layers: + g = sensitivities[name] / n_calibration + A = mod.lora_A.data.float() + B = mod.lora_B.data.float() + + delta_W = B @ A * mod.scale + U, S, Vt = torch.linalg.svd(delta_W, full_matrices=False) + r = A.shape[0] + S_r = S[:r] + + if policy == "hard": + # Keep components with positive sensitivity, zero out negative + mask = (g > 0).float() + else: + # smooth_abs: sigmoid-weighted by sensitivity magnitude + # Normalize g to [-1, 1] range, apply sigmoid + g_norm = g / (g.abs().max() + 1e-8) + mask = torch.sigmoid(5.0 * g_norm) # steep sigmoid + + # L1 norm preservation: scale mask so total nuclear norm is preserved + mask = mask * (S_r.sum() / (mask * S_r).sum().clamp(min=1e-8)) + + # Reconstruct: ΔW' = U_r @ diag(mask * S_r) @ Vt_r + S_new = mask * S_r + delta_W_new = U[:, :r] @ torch.diag(S_new) @ Vt[:r, :] + + # Factor back into B' @ A' * scale: use SVD of ΔW'/scale + dW_unscaled = delta_W_new / mod.scale + U2, S2, Vt2 = torch.linalg.svd(dW_unscaled, full_matrices=False) + sqrt_S2 = S2[:r].sqrt() + A_new = sqrt_S2.unsqueeze(1) * Vt2[:r, :] + B_new = U2[:, :r] * sqrt_S2.unsqueeze(0) + + ref_dtype = mod.lora_A.dtype + mod.lora_A.data = A_new.to(ref_dtype) + mod.lora_B.data = B_new.to(ref_dtype) + count += 1 + + kept = (mask > 0.5).sum().item() + print(f"[Spectral Surgery] {name}: kept {kept}/{r} components, " + f"sensitivity range [{g.min():.3f}, {g.max():.3f}]", flush=True) + + return count + + def load_lora(model: nn.Module, state_dict: dict) -> None: """Load LoRA weights into a model that has already had apply_lora() called.