feat: add LoRA dropout, LoRA+ asymmetric LR, and curriculum timestep sampling
- LoRA dropout: applied to the LoRA path only (not the frozen base weights); values of 0.05–0.1 help regularize on small datasets (arXiv:2404.09610)
- LoRA+: separate optimizer param groups for lora_A and lora_B with a configurable LR ratio; ratio=16 enables LoRA+ (arXiv:2402.12354)
- Curriculum mode: logit_normal timestep sampling for the first N% of steps, then uniform; directly addresses early convergence and fine-detail degradation at boundaries (arXiv:2603.12517)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -25,13 +25,14 @@ import torch.nn as nn
|
||||
class LoRALinear(nn.Module):
|
||||
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
|
||||
|
||||
Output: base(x) + (x @ A.T @ B.T) * (alpha / rank)
|
||||
Output: base(x) + (dropout(x) @ A.T @ B.T) * (alpha / rank)
|
||||
|
||||
A is initialised with Kaiming uniform; B is initialised to zero so the
|
||||
adapter contribution starts at zero and does not disturb pretrained behaviour.
|
||||
Dropout is applied only to the LoRA path, not the base linear.
|
||||
"""
|
||||
|
||||
def __init__(self, linear: nn.Linear, rank: int, alpha: float):
|
||||
def __init__(self, linear: nn.Linear, rank: int, alpha: float, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
in_f = linear.in_features
|
||||
out_f = linear.out_features
|
||||
@@ -46,16 +47,18 @@ class LoRALinear(nn.Module):
|
||||
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
|
||||
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
|
||||
self.scale = alpha / rank
|
||||
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
|
||||
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale
|
||||
return self.linear(x) + (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scale
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
rank = self.lora_A.shape[0]
|
||||
p = self.dropout.p if isinstance(self.dropout, nn.Dropout) else 0.0
|
||||
return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
|
||||
f"rank={rank}, scale={self.scale:.4f}")
|
||||
f"rank={rank}, scale={self.scale:.4f}, dropout={p}")
|
||||
|
||||
|
||||
def apply_lora(
|
||||
@@ -63,6 +66,7 @@ def apply_lora(
|
||||
rank: int = 16,
|
||||
alpha: float = None,
|
||||
target_suffixes: tuple = ("attn.qkv",),
|
||||
dropout: float = 0.0,
|
||||
) -> int:
|
||||
"""Replace matching nn.Linear layers with LoRALinear in-place.
|
||||
|
||||
@@ -74,6 +78,8 @@ def apply_lora(
|
||||
("attn.qkv",) which targets all SelfAttention QKV
|
||||
projections in the MM-DiT generator.
|
||||
Add "linear1" to also wrap post-attention output projections.
|
||||
dropout: Dropout probability on the LoRA path (not the base linear).
|
||||
0.05–0.1 helps regularize on small datasets.
|
||||
|
||||
Returns:
|
||||
Number of linear layers wrapped.
|
||||
@@ -92,7 +98,7 @@ def apply_lora(
|
||||
parent = model
|
||||
for part in parts[:-1]:
|
||||
parent = getattr(parent, part)
|
||||
setattr(parent, parts[-1], LoRALinear(module, rank, alpha))
|
||||
setattr(parent, parts[-1], LoRALinear(module, rank, alpha, dropout=dropout))
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
Reference in New Issue
Block a user