diff --git a/LORA_TRAINING.md b/LORA_TRAINING.md index 411e172..0d87b54 100644 --- a/LORA_TRAINING.md +++ b/LORA_TRAINING.md @@ -109,6 +109,7 @@ The script will: | `--warmup_steps` | `500` | Linear LR warmup steps | | `--grad_accum` | `4` | Gradient accumulation steps (effective batch = grad_accum × 1) | | `--save_every` | `500` | Save a checkpoint every N steps | +| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step01000.pt`) | | `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` | | `--seed` | `42` | Random seed | diff --git a/train_lora.py b/train_lora.py index d4b8681..8942725 100644 --- a/train_lora.py +++ b/train_lora.py @@ -161,6 +161,8 @@ def main(): parser.add_argument("--warmup_steps",type=int, default=500) parser.add_argument("--grad_accum", type=int, default=4, help="Gradient accumulation steps") parser.add_argument("--save_every", type=int, default=500) + parser.add_argument("--resume", default=None, + help="Path to a step checkpoint (.pt) to resume training from.") parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"]) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() @@ -295,15 +297,33 @@ def main(): scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25) + # --- Resume --- + start_step = 0 + if args.resume: + ckpt = torch.load(args.resume, map_location="cpu", weights_only=False) + if "step" not in ckpt: + print("[LoRA] ERROR: checkpoint has no step info — was it saved by this script?") + sys.exit(1) + start_step = ckpt["step"] + if start_step >= args.steps: + print(f"[LoRA] Checkpoint is already at step {start_step} >= --steps {args.steps}. Nothing to do.") + sys.exit(0) + net_generator.load_state_dict(ckpt["state_dict"], strict=False) + optimizer.load_state_dict(ckpt["optimizer"]) + scheduler.load_state_dict(ckpt["scheduler"]) + print(f"[LoRA] Resumed from {Path(args.resume).name} (step {start_step} → {args.steps})") + # --- Training loop --- net_generator.train() optimizer.zero_grad() - print(f"\n[LoRA] Training: {args.steps} steps, lr={args.lr}, grad_accum={args.grad_accum}") + remaining = args.steps - start_step + print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1} → {args.steps}), " + f"lr={args.lr}, grad_accum={args.grad_accum}") print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n") total_loss = 0.0 - for step in range(1, args.steps + 1): + for step in range(start_step + 1, args.steps + 1): x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) x1 = x1_cpu.to(device, dtype) @@ -336,9 +356,21 @@ def main(): total_loss = 0.0 if step % args.save_every == 0 or step == args.steps: - ckpt = output_dir / f"adapter_step{step:05d}.pt" - torch.save(get_lora_state_dict(net_generator), ckpt) - print(f"[LoRA] Saved {ckpt}") + ckpt_path = output_dir / f"adapter_step{step:05d}.pt" + torch.save({ + "state_dict": get_lora_state_dict(net_generator), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "step": step, + "meta": { + "variant": args.variant, + "rank": args.rank, + "alpha": args.alpha if args.alpha is not None else float(args.rank), + "target": args.target, + "steps": args.steps, + }, + }, ckpt_path) + print(f"[LoRA] Saved {ckpt_path}") # Save final adapter with embedded metadata final = output_dir / "adapter_final.pt"