feat: add cosine LR decay schedule to trainer and scheduler
- Add lr_schedule param (constant|cosine) to SelvaLoraTrainer - Cosine decays LR from initial value to ~0 after warmup, preventing the oscillation observed at steps 6000-8000 with lr=2e-4 flat - Wire lr_schedule through scheduler _PARAM_DEFAULTS and _train_inner call - Add g5_r128_lr_2e4_cosine and g5_r128_lr_3e4_cosine to r128_sweet_spot sweep Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -82,6 +82,21 @@
|
|||||||
"description": "Rank 256 + LR=3e-4. Best rank + best LR candidate combined.",
|
"description": "Rank 256 + LR=3e-4. Best rank + best LR candidate combined.",
|
||||||
"rank": 256,
|
"rank": 256,
|
||||||
"lr": 3e-4
|
"lr": 3e-4
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"id": "g5_r128_lr_2e4_cosine",
|
||||||
|
"group": "cosine",
|
||||||
|
"description": "LR=2e-4 + cosine decay. Fixes the oscillation observed at step 6000–8000 by decaying LR to ~0 instead of staying flat.",
|
||||||
|
"lr": 2e-4,
|
||||||
|
"lr_schedule": "cosine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "g5_r128_lr_3e4_cosine",
|
||||||
|
"group": "cosine",
|
||||||
|
"description": "LR=3e-4 + cosine decay. Higher LR with decay — should reach lower loss faster then lock in.",
|
||||||
|
"lr": 3e-4,
|
||||||
|
"lr_schedule": "cosine"
|
||||||
}
|
}
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ _PARAM_DEFAULTS = {
|
|||||||
"curriculum_switch": 0.6,
|
"curriculum_switch": 0.6,
|
||||||
"lora_dropout": 0.0,
|
"lora_dropout": 0.0,
|
||||||
"lora_plus_ratio": 1.0,
|
"lora_plus_ratio": 1.0,
|
||||||
|
"lr_schedule": "constant",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Palette for comparison chart: one color per experiment (cycles if > 8)
|
# Palette for comparison chart: one color per experiment (cycles if > 8)
|
||||||
@@ -386,6 +387,7 @@ class SelvaLoraScheduler:
|
|||||||
curr_switch = float(cfg.get("curriculum_switch", 0.6))
|
curr_switch = float(cfg.get("curriculum_switch", 0.6))
|
||||||
dropout = float(cfg.get("lora_dropout", 0.0))
|
dropout = float(cfg.get("lora_dropout", 0.0))
|
||||||
plus_ratio = float(cfg.get("lora_plus_ratio", 1.0))
|
plus_ratio = float(cfg.get("lora_plus_ratio", 1.0))
|
||||||
|
lr_schedule = str(cfg.get("lr_schedule", "constant"))
|
||||||
alpha_val = alpha if alpha > 0.0 else float(rank)
|
alpha_val = alpha if alpha > 0.0 else float(rank)
|
||||||
target_suffixes = tuple(target.strip().split())
|
target_suffixes = tuple(target.strip().split())
|
||||||
|
|
||||||
@@ -407,6 +409,7 @@ class SelvaLoraScheduler:
|
|||||||
"timestep_mode": ts_mode, "logit_normal_sigma": ln_sigma,
|
"timestep_mode": ts_mode, "logit_normal_sigma": ln_sigma,
|
||||||
"curriculum_switch": curr_switch,
|
"curriculum_switch": curr_switch,
|
||||||
"lora_dropout": dropout, "lora_plus_ratio": plus_ratio,
|
"lora_dropout": dropout, "lora_plus_ratio": plus_ratio,
|
||||||
|
"lr_schedule": lr_schedule,
|
||||||
},
|
},
|
||||||
"results": {"status": "running"},
|
"results": {"status": "running"},
|
||||||
"adapter_path": None,
|
"adapter_path": None,
|
||||||
@@ -425,6 +428,7 @@ class SelvaLoraScheduler:
|
|||||||
alpha_val, target_suffixes, batch_size, warmup,
|
alpha_val, target_suffixes, batch_size, warmup,
|
||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
ts_mode, ln_sigma, curr_switch, dropout, plus_ratio,
|
ts_mode, ln_sigma, curr_switch, dropout, plus_ratio,
|
||||||
|
lr_schedule,
|
||||||
)
|
)
|
||||||
|
|
||||||
duration = time.monotonic() - t_start
|
duration = time.monotonic() - t_start
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
|
import math
|
||||||
import random
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -528,6 +529,13 @@ class SelvaLoraTrainer:
|
|||||||
"tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. "
|
"tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. "
|
||||||
"1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).",
|
"1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).",
|
||||||
}),
|
}),
|
||||||
|
"lr_schedule": (["constant", "cosine"], {
|
||||||
|
"default": "constant",
|
||||||
|
"tooltip": "LR schedule after warmup. "
|
||||||
|
"constant: flat LR for all steps. "
|
||||||
|
"cosine: decay from lr to ~0 following a cosine curve — "
|
||||||
|
"prevents oscillation when LR is slightly too high.",
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -551,7 +559,7 @@ class SelvaLoraTrainer:
|
|||||||
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
||||||
grad_accum=1, save_every=500, resume_path="", seed=42,
|
grad_accum=1, save_every=500, resume_path="", seed=42,
|
||||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||||
lora_dropout=0.0, lora_plus_ratio=1.0):
|
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant"):
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@@ -601,7 +609,7 @@ class SelvaLoraTrainer:
|
|||||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
timestep_mode, logit_normal_sigma, curriculum_switch,
|
timestep_mode, logit_normal_sigma, curriculum_switch,
|
||||||
lora_dropout, lora_plus_ratio,
|
lora_dropout, lora_plus_ratio, lr_schedule,
|
||||||
)
|
)
|
||||||
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
|
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
|
||||||
|
|
||||||
@@ -612,7 +620,7 @@ class SelvaLoraTrainer:
|
|||||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||||
lora_dropout=0.0, lora_plus_ratio=1.0,
|
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant",
|
||||||
):
|
):
|
||||||
# --- Prepare generator copy with LoRA ---
|
# --- Prepare generator copy with LoRA ---
|
||||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||||
@@ -648,8 +656,16 @@ class SelvaLoraTrainer:
|
|||||||
if lora_plus_ratio != 1.0:
|
if lora_plus_ratio != 1.0:
|
||||||
print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True)
|
print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True)
|
||||||
|
|
||||||
def lr_lambda(s):
|
if lr_schedule == "cosine":
|
||||||
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
|
def lr_lambda(s):
|
||||||
|
if s < warmup_steps:
|
||||||
|
return s / max(1, warmup_steps)
|
||||||
|
progress = (s - warmup_steps) / max(1, steps - warmup_steps)
|
||||||
|
return max(1e-6 / lr, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
||||||
|
print(f"[LoRA Trainer] LR schedule: cosine decay {lr:.2e} → 0", flush=True)
|
||||||
|
else:
|
||||||
|
def lr_lambda(s):
|
||||||
|
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
|
||||||
|
|
||||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
||||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
||||||
@@ -701,6 +717,7 @@ class SelvaLoraTrainer:
|
|||||||
"curriculum_switch": curriculum_switch,
|
"curriculum_switch": curriculum_switch,
|
||||||
"lora_dropout": lora_dropout,
|
"lora_dropout": lora_dropout,
|
||||||
"lora_plus_ratio": lora_plus_ratio,
|
"lora_plus_ratio": lora_plus_ratio,
|
||||||
|
"lr_schedule": lr_schedule,
|
||||||
}
|
}
|
||||||
|
|
||||||
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
|
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
|
||||||
|
|||||||
Reference in New Issue
Block a user