From 1be07a80d225ab7bd4b8dbfcdb139108797bf34d Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Wed, 8 Apr 2026 13:25:01 +0200 Subject: [PATCH] feat: add cosine LR decay schedule to trainer and scheduler - Add lr_schedule param (constant|cosine) to SelvaLoraTrainer - Cosine decays LR from initial value to ~0 after warmup, preventing the oscillation observed at steps 6000-8000 with lr=2e-4 flat - Wire lr_schedule through scheduler _PARAM_DEFAULTS and _train_inner call - Add g5_r128_lr_2e4_cosine and g5_r128_lr_3e4_cosine to r128_sweet_spot sweep Co-Authored-By: Claude Sonnet 4.6 --- experiments/r128_sweet_spot.json | 15 +++++++++++++++ nodes/selva_lora_scheduler.py | 4 ++++ nodes/selva_lora_trainer.py | 27 ++++++++++++++++++++++----- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/experiments/r128_sweet_spot.json b/experiments/r128_sweet_spot.json index ce2c3b4..a26bfbe 100644 --- a/experiments/r128_sweet_spot.json +++ b/experiments/r128_sweet_spot.json @@ -82,6 +82,21 @@ "description": "Rank 256 + LR=3e-4. Best rank + best LR candidate combined.", "rank": 256, "lr": 3e-4 + }, + + { + "id": "g5_r128_lr_2e4_cosine", + "group": "cosine", + "description": "LR=2e-4 + cosine decay. Fixes the oscillation observed at step 6000–8000 by decaying LR to ~0 instead of staying flat.", + "lr": 2e-4, + "lr_schedule": "cosine" + }, + { + "id": "g5_r128_lr_3e4_cosine", + "group": "cosine", + "description": "LR=3e-4 + cosine decay. Higher LR with decay — should reach lower loss faster then lock in.", + "lr": 3e-4, + "lr_schedule": "cosine" } ] diff --git a/nodes/selva_lora_scheduler.py b/nodes/selva_lora_scheduler.py index 2db45c0..570eda3 100644 --- a/nodes/selva_lora_scheduler.py +++ b/nodes/selva_lora_scheduler.py @@ -78,6 +78,7 @@ _PARAM_DEFAULTS = { "curriculum_switch": 0.6, "lora_dropout": 0.0, "lora_plus_ratio": 1.0, + "lr_schedule": "constant", } # Palette for comparison chart: one color per experiment (cycles if > 8) @@ -386,6 +387,7 @@ class SelvaLoraScheduler: curr_switch = float(cfg.get("curriculum_switch", 0.6)) dropout = float(cfg.get("lora_dropout", 0.0)) plus_ratio = float(cfg.get("lora_plus_ratio", 1.0)) + lr_schedule = str(cfg.get("lr_schedule", "constant")) alpha_val = alpha if alpha > 0.0 else float(rank) target_suffixes = tuple(target.strip().split()) @@ -407,6 +409,7 @@ class SelvaLoraScheduler: "timestep_mode": ts_mode, "logit_normal_sigma": ln_sigma, "curriculum_switch": curr_switch, "lora_dropout": dropout, "lora_plus_ratio": plus_ratio, + "lr_schedule": lr_schedule, }, "results": {"status": "running"}, "adapter_path": None, @@ -425,6 +428,7 @@ class SelvaLoraScheduler: alpha_val, target_suffixes, batch_size, warmup, grad_accum, save_every, resume_path, seed, ts_mode, ln_sigma, curr_switch, dropout, plus_ratio, + lr_schedule, ) duration = time.monotonic() - t_start diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 02b89c3..42d8bf0 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -1,5 +1,6 @@ import copy import json +import math import random import traceback from pathlib import Path @@ -528,6 +529,13 @@ class SelvaLoraTrainer: "tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. " "1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).", }), + "lr_schedule": (["constant", "cosine"], { + "default": "constant", + "tooltip": "LR schedule after warmup. " + "constant: flat LR for all steps. " + "cosine: decay from lr to ~0 following a cosine curve — " + "prevents oscillation when LR is slightly too high.", + }), }, } @@ -551,7 +559,7 @@ class SelvaLoraTrainer: alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100, grad_accum=1, save_every=500, resume_path="", seed=42, timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6, - lora_dropout=0.0, lora_plus_ratio=1.0): + lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant"): torch.manual_seed(seed) random.seed(seed) @@ -601,7 +609,7 @@ class SelvaLoraTrainer: alpha_val, target_suffixes, batch_size, warmup_steps, grad_accum, save_every, resume_path, seed, timestep_mode, logit_normal_sigma, curriculum_switch, - lora_dropout, lora_plus_ratio, + lora_dropout, lora_plus_ratio, lr_schedule, ) return (r["patched_model"], r["adapter_path"], r["loss_curve"]) @@ -612,7 +620,7 @@ class SelvaLoraTrainer: alpha_val, target_suffixes, batch_size, warmup_steps, grad_accum, save_every, resume_path, seed, timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6, - lora_dropout=0.0, lora_plus_ratio=1.0, + lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant", ): # --- Prepare generator copy with LoRA --- generator = copy.deepcopy(model["generator"]).to(device, dtype) @@ -648,8 +656,16 @@ class SelvaLoraTrainer: if lora_plus_ratio != 1.0: print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True) - def lr_lambda(s): - return s / max(1, warmup_steps) if s < warmup_steps else 1.0 + if lr_schedule == "cosine": + def lr_lambda(s): + if s < warmup_steps: + return s / max(1, warmup_steps) + progress = (s - warmup_steps) / max(1, steps - warmup_steps) + return max(1e-6 / lr, 0.5 * (1.0 + math.cos(math.pi * progress))) + print(f"[LoRA Trainer] LR schedule: cosine decay {lr:.2e} → 0", flush=True) + else: + def lr_lambda(s): + return s / max(1, warmup_steps) if s < warmup_steps else 1.0 scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25) @@ -701,6 +717,7 @@ class SelvaLoraTrainer: "curriculum_switch": curriculum_switch, "lora_dropout": lora_dropout, "lora_plus_ratio": lora_plus_ratio, + "lr_schedule": lr_schedule, } # For curriculum mode: compute the step at which we switch from logit_normal to uniform