feat: add LoRA dropout, LoRA+ asymmetric LR, and curriculum timestep sampling

- LoRA dropout: applied to the LoRA path only (not frozen base weights),
  0.05–0.1 helps regularize on small datasets (arXiv:2404.09610)
- LoRA+: separate optimizer param groups for lora_A and lora_B with
  configurable LR ratio; ratio=16 enables LoRA+ (arXiv:2402.12354)
- Curriculum mode: logit_normal for first N% of steps then uniform,
  directly addresses early convergence + fine-detail degradation at
  boundaries (arXiv:2603.12517)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-06 12:43:18 +02:00
parent 5baa070e61
commit eb63c1ead7
3 changed files with 104 additions and 26 deletions
+54 -13
View File
@@ -271,17 +271,33 @@ class SelvaLoraTrainer:
"tooltip": "Path to a step checkpoint (.pt) to resume training from.", "tooltip": "Path to a step checkpoint (.pt) to resume training from.",
}), }),
"seed": ("INT", {"default": 42}), "seed": ("INT", {"default": 42}),
"timestep_mode": (["uniform", "logit_normal"], { "timestep_mode": (["uniform", "logit_normal", "curriculum"], {
"default": "uniform", "default": "uniform",
"tooltip": "How to sample training timesteps. " "tooltip": "How to sample training timesteps. "
"uniform samples all timesteps equally (default, matches original MMAudio training). " "uniform: all timesteps equally (matches original MMAudio). "
"logit_normal concentrates steps near t=0.5 — reaches lower loss but perceptual improvement is dataset-dependent.", "logit_normal: concentrates near t=0.5. "
"curriculum: logit_normal for first curriculum_switch% of steps then uniform (recommended for small datasets).",
}), }),
"logit_normal_sigma": ("FLOAT", { "logit_normal_sigma": ("FLOAT", {
"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1, "default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1,
"tooltip": "Spread of the logit-normal distribution. " "tooltip": "Spread of the logit-normal distribution. "
"1.0 = moderate peak at t=0.5. Higher approaches uniform. " "1.0 = moderate peak at t=0.5. Higher approaches uniform. "
"Only used when timestep_mode=logit_normal.", "Used with logit_normal and curriculum modes.",
}),
"curriculum_switch": ("FLOAT", {
"default": 0.6, "min": 0.1, "max": 0.9, "step": 0.05,
"tooltip": "Fraction of steps to run logit_normal before switching to uniform. "
"0.6 = switch at 60% of total steps. Only used with timestep_mode=curriculum.",
}),
"lora_dropout": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.01,
"tooltip": "Dropout applied to the LoRA path only (not the frozen base weights). "
"0=disabled. 0.050.1 helps regularize on small datasets (arXiv:2404.09610).",
}),
"lora_plus_ratio": ("FLOAT", {
"default": 1.0, "min": 1.0, "max": 32.0, "step": 1.0,
"tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. "
"1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).",
}), }),
}, },
} }
@@ -305,7 +321,8 @@ class SelvaLoraTrainer:
def train(self, model, data_dir, output_dir, steps, rank, lr, def train(self, model, data_dir, output_dir, steps, rank, lr,
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100, alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
grad_accum=1, save_every=500, resume_path="", seed=42, grad_accum=1, save_every=500, resume_path="", seed=42,
timestep_mode="uniform", logit_normal_sigma=1.0): timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
lora_dropout=0.0, lora_plus_ratio=1.0):
torch.manual_seed(seed) torch.manual_seed(seed)
random.seed(seed) random.seed(seed)
@@ -442,7 +459,8 @@ class SelvaLoraTrainer:
data_dir, output_dir, steps, rank, lr, data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup_steps, alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed, grad_accum, save_every, resume_path, seed,
timestep_mode, logit_normal_sigma, timestep_mode, logit_normal_sigma, curriculum_switch,
lora_dropout, lora_plus_ratio,
) )
def _train_inner( def _train_inner(
@@ -451,19 +469,21 @@ class SelvaLoraTrainer:
data_dir, output_dir, steps, rank, lr, data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup_steps, alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed, grad_accum, save_every, resume_path, seed,
timestep_mode="uniform", logit_normal_sigma=1.0, timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
lora_dropout=0.0, lora_plus_ratio=1.0,
): ):
# --- Prepare generator copy with LoRA --- # --- Prepare generator copy with LoRA ---
generator = copy.deepcopy(model["generator"]).to(device, dtype) generator = copy.deepcopy(model["generator"]).to(device, dtype)
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val, n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
target_suffixes=target_suffixes) target_suffixes=target_suffixes, dropout=lora_dropout)
if n_lora == 0: if n_lora == 0:
raise RuntimeError( raise RuntimeError(
f"[LoRA Trainer] No layers matched target={target_suffixes}. " f"[LoRA Trainer] No layers matched target={target_suffixes}. "
"Check the 'target' field." "Check the 'target' field."
) )
print(f"[LoRA Trainer] Wrapped {n_lora} layers (rank={rank}, alpha={alpha_val})", flush=True) print(f"[LoRA Trainer] Wrapped {n_lora} layers "
f"(rank={rank}, alpha={alpha_val}, dropout={lora_dropout})", flush=True)
for name, p in generator.named_parameters(): for name, p in generator.named_parameters():
p.requires_grad_("lora_" in name) p.requires_grad_("lora_" in name)
@@ -475,8 +495,16 @@ class SelvaLoraTrainer:
) )
# --- Optimizer + scheduler --- # --- Optimizer + scheduler ---
lora_params = [p for p in generator.parameters() if p.requires_grad] # LoRA+: split A and B into separate param groups with different LRs.
optimizer = torch.optim.AdamW(lora_params, lr=lr, weight_decay=1e-2) # ratio=1.0 = standard LoRA (same LR for both). ratio=16 = LoRA+.
lora_A_params = [p for n, p in generator.named_parameters() if "lora_A" in n and p.requires_grad]
lora_B_params = [p for n, p in generator.named_parameters() if "lora_B" in n and p.requires_grad]
optimizer = torch.optim.AdamW([
{"params": lora_A_params, "lr": lr},
{"params": lora_B_params, "lr": lr * lora_plus_ratio},
], weight_decay=1e-2)
if lora_plus_ratio != 1.0:
print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True)
def lr_lambda(s): def lr_lambda(s):
return s / max(1, warmup_steps) if s < warmup_steps else 1.0 return s / max(1, warmup_steps) if s < warmup_steps else 1.0
@@ -518,8 +546,15 @@ class SelvaLoraTrainer:
"steps": steps, "steps": steps,
"timestep_mode": timestep_mode, "timestep_mode": timestep_mode,
"logit_normal_sigma": logit_normal_sigma, "logit_normal_sigma": logit_normal_sigma,
"curriculum_switch": curriculum_switch,
"lora_dropout": lora_dropout,
"lora_plus_ratio": lora_plus_ratio,
} }
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
curriculum_switch_step = start_step + int((steps - start_step) * curriculum_switch)
_curriculum_switched = False
print(f"\n[LoRA Trainer] Training {remaining} steps " print(f"\n[LoRA Trainer] Training {remaining} steps "
f"(step {start_step + 1}{steps}, batch_size={batch_size}, " f"(step {start_step + 1}{steps}, batch_size={batch_size}, "
f"timestep_mode={timestep_mode})\n", flush=True) f"timestep_mode={timestep_mode})\n", flush=True)
@@ -538,11 +573,17 @@ class SelvaLoraTrainer:
generator.normalize(x1) generator.normalize(x1)
if timestep_mode == "logit_normal": if timestep_mode == "logit_normal" or (
timestep_mode == "curriculum" and step <= curriculum_switch_step
):
u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma
t = torch.sigmoid(u) t = torch.sigmoid(u)
else: else:
t = torch.rand(batch_size, device=device, dtype=dtype) t = torch.rand(batch_size, device=device, dtype=dtype)
if timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
print(f"[LoRA Trainer] Curriculum switch: logit_normal → uniform at step {step}", flush=True)
_curriculum_switched = True
x0 = torch.randn_like(x1) x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t) xt = fm.get_conditional_flow(x0, x1, t)
@@ -552,7 +593,7 @@ class SelvaLoraTrainer:
running_loss += loss.item() * grad_accum running_loss += loss.item() * grad_accum
if step % grad_accum == 0: if step % grad_accum == 0:
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0)
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
+11 -5
View File
@@ -25,13 +25,14 @@ import torch.nn as nn
class LoRALinear(nn.Module): class LoRALinear(nn.Module):
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices. """nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
Output: base(x) + (x @ A.T @ B.T) * (alpha / rank) Output: base(x) + (dropout(x) @ A.T @ B.T) * (alpha / rank)
A is initialised with Kaiming uniform; B is initialised to zero so the A is initialised with Kaiming uniform; B is initialised to zero so the
adapter contribution starts at zero and does not disturb pretrained behaviour. adapter contribution starts at zero and does not disturb pretrained behaviour.
Dropout is applied only to the LoRA path, not the base linear.
""" """
def __init__(self, linear: nn.Linear, rank: int, alpha: float): def __init__(self, linear: nn.Linear, rank: int, alpha: float, dropout: float = 0.0):
super().__init__() super().__init__()
in_f = linear.in_features in_f = linear.in_features
out_f = linear.out_features out_f = linear.out_features
@@ -46,16 +47,18 @@ class LoRALinear(nn.Module):
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device)) self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device)) self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
self.scale = alpha / rank self.scale = alpha / rank
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale return self.linear(x) + (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scale
def extra_repr(self) -> str: def extra_repr(self) -> str:
rank = self.lora_A.shape[0] rank = self.lora_A.shape[0]
p = self.dropout.p if isinstance(self.dropout, nn.Dropout) else 0.0
return (f"in={self.linear.in_features}, out={self.linear.out_features}, " return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
f"rank={rank}, scale={self.scale:.4f}") f"rank={rank}, scale={self.scale:.4f}, dropout={p}")
def apply_lora( def apply_lora(
@@ -63,6 +66,7 @@ def apply_lora(
rank: int = 16, rank: int = 16,
alpha: float = None, alpha: float = None,
target_suffixes: tuple = ("attn.qkv",), target_suffixes: tuple = ("attn.qkv",),
dropout: float = 0.0,
) -> int: ) -> int:
"""Replace matching nn.Linear layers with LoRALinear in-place. """Replace matching nn.Linear layers with LoRALinear in-place.
@@ -74,6 +78,8 @@ def apply_lora(
("attn.qkv",) which targets all SelfAttention QKV ("attn.qkv",) which targets all SelfAttention QKV
projections in the MM-DiT generator. projections in the MM-DiT generator.
Add "linear1" to also wrap post-attention output projections. Add "linear1" to also wrap post-attention output projections.
dropout: Dropout probability on the LoRA path (not the base linear).
0.050.1 helps regularize on small datasets.
Returns: Returns:
Number of linear layers wrapped. Number of linear layers wrapped.
@@ -92,7 +98,7 @@ def apply_lora(
parent = model parent = model
for part in parts[:-1]: for part in parts[:-1]:
parent = getattr(parent, part) parent = getattr(parent, part)
setattr(parent, parts[-1], LoRALinear(module, rank, alpha)) setattr(parent, parts[-1], LoRALinear(module, rank, alpha, dropout=dropout))
count += 1 count += 1
return count return count
+39 -8
View File
@@ -167,10 +167,16 @@ def main():
help="Path to a step checkpoint (.pt) to resume training from.") help="Path to a step checkpoint (.pt) to resume training from.")
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"]) parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
parser.add_argument("--seed", type=int, default=42) parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal"], parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal", "curriculum"],
help="Timestep sampling distribution. uniform matches original MMAudio training. logit_normal reaches lower loss but perceptual improvement is dataset-dependent.") help="Timestep sampling. uniform=original MMAudio, logit_normal=concentrated near t=0.5, curriculum=logit_normal then uniform.")
parser.add_argument("--logit_normal_sigma", type=float, default=1.0, parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
help="Spread of logit-normal distribution (only used with --timestep_mode logit_normal).") help="Spread of logit-normal distribution.")
parser.add_argument("--curriculum_switch", type=float, default=0.6,
help="Fraction of steps to use logit_normal before switching to uniform (curriculum mode only).")
parser.add_argument("--lora_dropout", type=float, default=0.0,
help="Dropout on the LoRA path only. 0.050.1 helps on small datasets.")
parser.add_argument("--lora_plus_ratio", type=float, default=1.0,
help="LoRA+ LR ratio: lr_B = lr * ratio. 1.0=standard, 16.0=LoRA+.")
args = parser.parse_args() args = parser.parse_args()
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
@@ -234,8 +240,9 @@ def main():
rank=args.rank, rank=args.rank,
alpha=args.alpha, alpha=args.alpha,
target_suffixes=tuple(args.target), target_suffixes=tuple(args.target),
dropout=args.lora_dropout,
) )
print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target})") print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target}, dropout={args.lora_dropout})")
if n_lora == 0: if n_lora == 0:
print("[LoRA] ERROR: no layers were wrapped — check --target names.") print("[LoRA] ERROR: no layers were wrapped — check --target names.")
sys.exit(1) sys.exit(1)
@@ -315,8 +322,16 @@ def main():
print(f"[LoRA] {len(dataset)} clip(s) ready.") print(f"[LoRA] {len(dataset)} clip(s) ready.")
# --- Optimizer + LR scheduler --- # --- Optimizer + LR scheduler ---
lora_params = [p for p in net_generator.parameters() if p.requires_grad] # LoRA+: separate param groups for A and B with different LRs.
optimizer = torch.optim.AdamW(lora_params, lr=args.lr, weight_decay=1e-2) # ratio=1.0 = standard LoRA. ratio=16 = LoRA+ (arXiv:2402.12354).
lora_A_params = [p for n, p in net_generator.named_parameters() if "lora_A" in n and p.requires_grad]
lora_B_params = [p for n, p in net_generator.named_parameters() if "lora_B" in n and p.requires_grad]
optimizer = torch.optim.AdamW([
{"params": lora_A_params, "lr": args.lr},
{"params": lora_B_params, "lr": args.lr * args.lora_plus_ratio},
], weight_decay=1e-2)
if args.lora_plus_ratio != 1.0:
print(f"[LoRA] LoRA+: lr_A={args.lr:.2e} lr_B={args.lr * args.lora_plus_ratio:.2e}")
def lr_lambda(step): def lr_lambda(step):
if step < args.warmup_steps: if step < args.warmup_steps:
@@ -351,6 +366,9 @@ def main():
f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}") f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n") print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
curriculum_switch_step = start_step + int((args.steps - start_step) * args.curriculum_switch)
_curriculum_switched = False
total_loss = 0.0 total_loss = 0.0
for step in range(start_step + 1, args.steps + 1): for step in range(start_step + 1, args.steps + 1):
batch = random.choices(dataset, k=args.batch_size) batch = random.choices(dataset, k=args.batch_size)
@@ -363,11 +381,18 @@ def main():
net_generator.normalize(x1) net_generator.normalize(x1)
if args.timestep_mode == "logit_normal": if args.timestep_mode == "logit_normal" or (
args.timestep_mode == "curriculum" and step <= curriculum_switch_step
):
u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma
t = torch.sigmoid(u) t = torch.sigmoid(u)
else: else:
t = torch.rand(args.batch_size, device=device, dtype=dtype) t = torch.rand(args.batch_size, device=device, dtype=dtype)
if args.timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
print(f"[LoRA] Curriculum switch: logit_normal → uniform at step {step}")
_curriculum_switched = True
x0 = torch.randn_like(x1) x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t) xt = fm.get_conditional_flow(x0, x1, t)
@@ -378,7 +403,7 @@ def main():
total_loss += loss.item() * args.grad_accum total_loss += loss.item() * args.grad_accum
if step % args.grad_accum == 0: if step % args.grad_accum == 0:
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0)
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
@@ -404,6 +429,9 @@ def main():
"steps": args.steps, "steps": args.steps,
"timestep_mode": args.timestep_mode, "timestep_mode": args.timestep_mode,
"logit_normal_sigma": args.logit_normal_sigma, "logit_normal_sigma": args.logit_normal_sigma,
"curriculum_switch": args.curriculum_switch,
"lora_dropout": args.lora_dropout,
"lora_plus_ratio": args.lora_plus_ratio,
}, },
}, ckpt_path) }, ckpt_path)
print(f"[LoRA] Saved {ckpt_path}") print(f"[LoRA] Saved {ckpt_path}")
@@ -424,6 +452,9 @@ def main():
"steps": args.steps, "steps": args.steps,
"timestep_mode": args.timestep_mode, "timestep_mode": args.timestep_mode,
"logit_normal_sigma": args.logit_normal_sigma, "logit_normal_sigma": args.logit_normal_sigma,
"curriculum_switch": args.curriculum_switch,
"lora_dropout": args.lora_dropout,
"lora_plus_ratio": args.lora_plus_ratio,
} }
torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final) torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2)) (output_dir / "meta.json").write_text(json.dumps(meta, indent=2))