feat: add resume support to train_lora.py
Step checkpoints now save optimizer state, scheduler state, and step number alongside the LoRA weights. Pass --resume path/to/adapter_stepXXXXX.pt to continue training from that checkpoint. --steps always means total steps, so resuming from 1000 with --steps 2000 trains 1000 more steps. adapter_final.pt format is unchanged (state_dict + meta only) so SelvaLoraLoader remains compatible. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -109,6 +109,7 @@ The script will:
|
|||||||
| `--warmup_steps` | `500` | Linear LR warmup steps |
|
| `--warmup_steps` | `500` | Linear LR warmup steps |
|
||||||
| `--grad_accum` | `4` | Gradient accumulation steps (effective batch = grad_accum × 1) |
|
| `--grad_accum` | `4` | Gradient accumulation steps (effective batch = grad_accum × 1) |
|
||||||
| `--save_every` | `500` | Save a checkpoint every N steps |
|
| `--save_every` | `500` | Save a checkpoint every N steps |
|
||||||
|
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step01000.pt`) |
|
||||||
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
|
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
|
||||||
| `--seed` | `42` | Random seed |
|
| `--seed` | `42` | Random seed |
|
||||||
|
|
||||||
|
|||||||
+37
-5
@@ -161,6 +161,8 @@ def main():
|
|||||||
parser.add_argument("--warmup_steps",type=int, default=500)
|
parser.add_argument("--warmup_steps",type=int, default=500)
|
||||||
parser.add_argument("--grad_accum", type=int, default=4, help="Gradient accumulation steps")
|
parser.add_argument("--grad_accum", type=int, default=4, help="Gradient accumulation steps")
|
||||||
parser.add_argument("--save_every", type=int, default=500)
|
parser.add_argument("--save_every", type=int, default=500)
|
||||||
|
parser.add_argument("--resume", default=None,
|
||||||
|
help="Path to a step checkpoint (.pt) to resume training from.")
|
||||||
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
||||||
parser.add_argument("--seed", type=int, default=42)
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -295,15 +297,33 @@ def main():
|
|||||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
||||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
||||||
|
|
||||||
|
# --- Resume ---
|
||||||
|
start_step = 0
|
||||||
|
if args.resume:
|
||||||
|
ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
|
||||||
|
if "step" not in ckpt:
|
||||||
|
print("[LoRA] ERROR: checkpoint has no step info — was it saved by this script?")
|
||||||
|
sys.exit(1)
|
||||||
|
start_step = ckpt["step"]
|
||||||
|
if start_step >= args.steps:
|
||||||
|
print(f"[LoRA] Checkpoint is already at step {start_step} >= --steps {args.steps}. Nothing to do.")
|
||||||
|
sys.exit(0)
|
||||||
|
net_generator.load_state_dict(ckpt["state_dict"], strict=False)
|
||||||
|
optimizer.load_state_dict(ckpt["optimizer"])
|
||||||
|
scheduler.load_state_dict(ckpt["scheduler"])
|
||||||
|
print(f"[LoRA] Resumed from {Path(args.resume).name} (step {start_step} → {args.steps})")
|
||||||
|
|
||||||
# --- Training loop ---
|
# --- Training loop ---
|
||||||
net_generator.train()
|
net_generator.train()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
print(f"\n[LoRA] Training: {args.steps} steps, lr={args.lr}, grad_accum={args.grad_accum}")
|
remaining = args.steps - start_step
|
||||||
|
print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1} → {args.steps}), "
|
||||||
|
f"lr={args.lr}, grad_accum={args.grad_accum}")
|
||||||
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
|
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
|
||||||
|
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
for step in range(1, args.steps + 1):
|
for step in range(start_step + 1, args.steps + 1):
|
||||||
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
||||||
|
|
||||||
x1 = x1_cpu.to(device, dtype)
|
x1 = x1_cpu.to(device, dtype)
|
||||||
@@ -336,9 +356,21 @@ def main():
|
|||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
|
|
||||||
if step % args.save_every == 0 or step == args.steps:
|
if step % args.save_every == 0 or step == args.steps:
|
||||||
ckpt = output_dir / f"adapter_step{step:05d}.pt"
|
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
||||||
torch.save(get_lora_state_dict(net_generator), ckpt)
|
torch.save({
|
||||||
print(f"[LoRA] Saved {ckpt}")
|
"state_dict": get_lora_state_dict(net_generator),
|
||||||
|
"optimizer": optimizer.state_dict(),
|
||||||
|
"scheduler": scheduler.state_dict(),
|
||||||
|
"step": step,
|
||||||
|
"meta": {
|
||||||
|
"variant": args.variant,
|
||||||
|
"rank": args.rank,
|
||||||
|
"alpha": args.alpha if args.alpha is not None else float(args.rank),
|
||||||
|
"target": args.target,
|
||||||
|
"steps": args.steps,
|
||||||
|
},
|
||||||
|
}, ckpt_path)
|
||||||
|
print(f"[LoRA] Saved {ckpt_path}")
|
||||||
|
|
||||||
# Save final adapter with embedded metadata
|
# Save final adapter with embedded metadata
|
||||||
final = output_dir / "adapter_final.pt"
|
final = output_dir / "adapter_final.pt"
|
||||||
|
|||||||
Reference in New Issue
Block a user