feat: add LoRA dropout, LoRA+ asymmetric LR, and curriculum timestep sampling

- LoRA dropout: applied to the LoRA path only (not frozen base weights),
  0.05–0.1 helps regularize on small datasets (arXiv:2404.09610)
- LoRA+: separate optimizer param groups for lora_A and lora_B with
  configurable LR ratio; ratio=16 enables LoRA+ (arXiv:2402.12354)
- Curriculum mode: sample timesteps with logit_normal for the first N% of
  steps, then switch to uniform; directly addresses early convergence and
  fine-detail degradation at boundaries (arXiv:2603.12517)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-06 12:43:18 +02:00
parent 5baa070e61
commit eb63c1ead7
3 changed files with 104 additions and 26 deletions
+11 -5
View File
@@ -25,13 +25,14 @@ import torch.nn as nn
class LoRALinear(nn.Module):
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
Output: base(x) + (x @ A.T @ B.T) * (alpha / rank)
Output: base(x) + (dropout(x) @ A.T @ B.T) * (alpha / rank)
A is initialised with Kaiming uniform; B is initialised to zero so the
adapter contribution starts at zero and does not disturb pretrained behaviour.
Dropout is applied only to the LoRA path, not the base linear.
"""
def __init__(self, linear: nn.Linear, rank: int, alpha: float):
def __init__(self, linear: nn.Linear, rank: int, alpha: float, dropout: float = 0.0):
super().__init__()
in_f = linear.in_features
out_f = linear.out_features
@@ -46,16 +47,18 @@ class LoRALinear(nn.Module):
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
self.scale = alpha / rank
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale
return self.linear(x) + (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scale
def extra_repr(self) -> str:
rank = self.lora_A.shape[0]
p = self.dropout.p if isinstance(self.dropout, nn.Dropout) else 0.0
return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
f"rank={rank}, scale={self.scale:.4f}")
f"rank={rank}, scale={self.scale:.4f}, dropout={p}")
def apply_lora(
@@ -63,6 +66,7 @@ def apply_lora(
rank: int = 16,
alpha: float = None,
target_suffixes: tuple = ("attn.qkv",),
dropout: float = 0.0,
) -> int:
"""Replace matching nn.Linear layers with LoRALinear in-place.
@@ -74,6 +78,8 @@ def apply_lora(
("attn.qkv",) which targets all SelfAttention QKV
projections in the MM-DiT generator.
Add "linear1" to also wrap post-attention output projections.
dropout: Dropout probability on the LoRA path (not the base linear).
0.05–0.1 helps regularize on small datasets.
Returns:
Number of linear layers wrapped.
@@ -92,7 +98,7 @@ def apply_lora(
parent = model
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1], LoRALinear(module, rank, alpha))
setattr(parent, parts[-1], LoRALinear(module, rank, alpha, dropout=dropout))
count += 1
return count