feat: add LoRA dropout, LoRA+ asymmetric LR, and curriculum timestep sampling
- LoRA dropout: applied to the LoRA path only (not frozen base weights), 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610) - LoRA+: separate optimizer param groups for lora_A and lora_B with configurable LR ratio; ratio=16 enables LoRA+ (arXiv:2402.12354) - Curriculum mode: logit_normal for first N% of steps then uniform, directly addresses early convergence + fine-detail degradation at boundaries (arXiv:2603.12517) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+39
-8
@@ -167,10 +167,16 @@ def main():
|
||||
help="Path to a step checkpoint (.pt) to resume training from.")
|
||||
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal"],
|
||||
help="Timestep sampling distribution. uniform matches original MMAudio training. logit_normal reaches lower loss but perceptual improvement is dataset-dependent.")
|
||||
parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal", "curriculum"],
|
||||
help="Timestep sampling. uniform=original MMAudio, logit_normal=concentrated near t=0.5, curriculum=logit_normal then uniform.")
|
||||
parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
|
||||
help="Spread of logit-normal distribution (only used with --timestep_mode logit_normal).")
|
||||
help="Spread of logit-normal distribution.")
|
||||
parser.add_argument("--curriculum_switch", type=float, default=0.6,
|
||||
help="Fraction of steps to use logit_normal before switching to uniform (curriculum mode only).")
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0,
|
||||
help="Dropout on the LoRA path only. 0.05–0.1 helps on small datasets.")
|
||||
parser.add_argument("--lora_plus_ratio", type=float, default=1.0,
|
||||
help="LoRA+ LR ratio: lr_B = lr * ratio. 1.0=standard, 16.0=LoRA+.")
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
@@ -234,8 +240,9 @@ def main():
|
||||
rank=args.rank,
|
||||
alpha=args.alpha,
|
||||
target_suffixes=tuple(args.target),
|
||||
dropout=args.lora_dropout,
|
||||
)
|
||||
print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target})")
|
||||
print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target}, dropout={args.lora_dropout})")
|
||||
if n_lora == 0:
|
||||
print("[LoRA] ERROR: no layers were wrapped — check --target names.")
|
||||
sys.exit(1)
|
||||
@@ -315,8 +322,16 @@ def main():
|
||||
print(f"[LoRA] {len(dataset)} clip(s) ready.")
|
||||
|
||||
# --- Optimizer + LR scheduler ---
|
||||
lora_params = [p for p in net_generator.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.AdamW(lora_params, lr=args.lr, weight_decay=1e-2)
|
||||
# LoRA+: separate param groups for A and B with different LRs.
|
||||
# ratio=1.0 = standard LoRA. ratio=16 = LoRA+ (arXiv:2402.12354).
|
||||
lora_A_params = [p for n, p in net_generator.named_parameters() if "lora_A" in n and p.requires_grad]
|
||||
lora_B_params = [p for n, p in net_generator.named_parameters() if "lora_B" in n and p.requires_grad]
|
||||
optimizer = torch.optim.AdamW([
|
||||
{"params": lora_A_params, "lr": args.lr},
|
||||
{"params": lora_B_params, "lr": args.lr * args.lora_plus_ratio},
|
||||
], weight_decay=1e-2)
|
||||
if args.lora_plus_ratio != 1.0:
|
||||
print(f"[LoRA] LoRA+: lr_A={args.lr:.2e} lr_B={args.lr * args.lora_plus_ratio:.2e}")
|
||||
|
||||
def lr_lambda(step):
|
||||
if step < args.warmup_steps:
|
||||
@@ -351,6 +366,9 @@ def main():
|
||||
f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
|
||||
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
|
||||
|
||||
curriculum_switch_step = start_step + int((args.steps - start_step) * args.curriculum_switch)
|
||||
_curriculum_switched = False
|
||||
|
||||
total_loss = 0.0
|
||||
for step in range(start_step + 1, args.steps + 1):
|
||||
batch = random.choices(dataset, k=args.batch_size)
|
||||
@@ -363,11 +381,18 @@ def main():
|
||||
|
||||
net_generator.normalize(x1)
|
||||
|
||||
if args.timestep_mode == "logit_normal":
|
||||
if args.timestep_mode == "logit_normal" or (
|
||||
args.timestep_mode == "curriculum" and step <= curriculum_switch_step
|
||||
):
|
||||
u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma
|
||||
t = torch.sigmoid(u)
|
||||
else:
|
||||
t = torch.rand(args.batch_size, device=device, dtype=dtype)
|
||||
|
||||
if args.timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
|
||||
print(f"[LoRA] Curriculum switch: logit_normal → uniform at step {step}")
|
||||
_curriculum_switched = True
|
||||
|
||||
x0 = torch.randn_like(x1)
|
||||
xt = fm.get_conditional_flow(x0, x1, t)
|
||||
|
||||
@@ -378,7 +403,7 @@ def main():
|
||||
total_loss += loss.item() * args.grad_accum
|
||||
|
||||
if step % args.grad_accum == 0:
|
||||
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
|
||||
torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0)
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
@@ -404,6 +429,9 @@ def main():
|
||||
"steps": args.steps,
|
||||
"timestep_mode": args.timestep_mode,
|
||||
"logit_normal_sigma": args.logit_normal_sigma,
|
||||
"curriculum_switch": args.curriculum_switch,
|
||||
"lora_dropout": args.lora_dropout,
|
||||
"lora_plus_ratio": args.lora_plus_ratio,
|
||||
},
|
||||
}, ckpt_path)
|
||||
print(f"[LoRA] Saved {ckpt_path}")
|
||||
@@ -424,6 +452,9 @@ def main():
|
||||
"steps": args.steps,
|
||||
"timestep_mode": args.timestep_mode,
|
||||
"logit_normal_sigma": args.logit_normal_sigma,
|
||||
"curriculum_switch": args.curriculum_switch,
|
||||
"lora_dropout": args.lora_dropout,
|
||||
"lora_plus_ratio": args.lora_plus_ratio,
|
||||
}
|
||||
torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
|
||||
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||||
|
||||
Reference in New Issue
Block a user