feat: sweep resume + 5 additional experiments (LR, target, extended)
Scheduler: on re-run, it reads the existing experiment_summary.json and skips experiments already marked completed, so a sweep is safe to stop and restart midway. tier1_thorough: adds g5 (LR 3e-5 and 3e-4), g6 (full target attn.qkv + linear1 at r16 and r64), and g4_full_r64_6k (extended 6000-step run), for 17 experiments total. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -98,6 +98,47 @@
|
|||||||
"lora_plus_ratio": 16.0,
|
"lora_plus_ratio": 16.0,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"timestep_mode": "curriculum"
|
"timestep_mode": "curriculum"
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"id": "g5_lr_low",
|
||||||
|
"group": "lr",
|
||||||
|
"description": "LR=3e-5 — 3× lower than baseline. Tests if 1e-4 is overshooting.",
|
||||||
|
"lr": 3e-5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "g5_lr_high",
|
||||||
|
"group": "lr",
|
||||||
|
"description": "LR=3e-4 — 3× higher than baseline. Tests if 1e-4 is too conservative.",
|
||||||
|
"lr": 3e-4
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"id": "g6_target_full_r16",
|
||||||
|
"group": "target",
|
||||||
|
"description": "Rank 16 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
|
||||||
|
"target": "attn.qkv linear1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "g6_target_full_r64",
|
||||||
|
"group": "target",
|
||||||
|
"description": "Rank 64 + alpha=32 targeting attn.qkv + linear1. Maximum coverage + expressiveness.",
|
||||||
|
"rank": 64,
|
||||||
|
"alpha": 32.0,
|
||||||
|
"target": "attn.qkv linear1"
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"id": "g4_full_r64_6k",
|
||||||
|
"group": "combined",
|
||||||
|
"description": "All Tier 1 at rank 64 + alpha=32, extended to 6000 steps. Checks if convergence is done at 4000.",
|
||||||
|
"rank": 64,
|
||||||
|
"alpha": 32.0,
|
||||||
|
"lora_plus_ratio": 16.0,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"timestep_mode": "curriculum",
|
||||||
|
"steps": 6000,
|
||||||
|
"save_every": 1000
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -296,20 +296,50 @@ class SelvaLoraScheduler:
|
|||||||
dataset = _prepare_dataset(model, data_dir, device)
|
dataset = _prepare_dataset(model, data_dir, device)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 4. Build the summary skeleton (written incrementally)
|
# 4. Build or restore the summary (resume-aware)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
summary = {
|
summary_path = output_root / "experiment_summary.json"
|
||||||
"sweep_name": sweep_name,
|
completed_ids = set()
|
||||||
"description": description,
|
all_curve_data = [] # collected for comparison image
|
||||||
"sweep_file": str(exp_path),
|
|
||||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
if summary_path.exists():
|
||||||
"completed_at": None,
|
try:
|
||||||
"system": _get_system_info(),
|
existing = json.loads(summary_path.read_text(encoding="utf-8"))
|
||||||
"data_dir": str(data_dir),
|
for rec in existing.get("experiments", []):
|
||||||
"n_clips": n_clips,
|
if rec.get("results", {}).get("status") == "completed":
|
||||||
"experiments": [],
|
completed_ids.add(rec["id"])
|
||||||
}
|
lh = rec["results"].get("loss_history", [])
|
||||||
summary_path = output_root / "experiment_summary.json"
|
all_curve_data.append({
|
||||||
|
"id": rec["id"],
|
||||||
|
"loss_history": lh,
|
||||||
|
"log_interval": rec["results"].get("log_interval", 50),
|
||||||
|
"start_step": 0,
|
||||||
|
})
|
||||||
|
# Restore the original summary, clear completed_at so it gets set again
|
||||||
|
summary = existing
|
||||||
|
summary["completed_at"] = None
|
||||||
|
if completed_ids:
|
||||||
|
print(f"[LoRA Scheduler] Resuming — skipping {len(completed_ids)} "
|
||||||
|
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[LoRA Scheduler] Could not read existing summary ({e}) — starting fresh",
|
||||||
|
flush=True)
|
||||||
|
completed_ids = set()
|
||||||
|
all_curve_data = []
|
||||||
|
summary = None
|
||||||
|
|
||||||
|
if not completed_ids:
|
||||||
|
summary = {
|
||||||
|
"sweep_name": sweep_name,
|
||||||
|
"description": description,
|
||||||
|
"sweep_file": str(exp_path),
|
||||||
|
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"completed_at": None,
|
||||||
|
"system": _get_system_info(),
|
||||||
|
"data_dir": str(data_dir),
|
||||||
|
"n_clips": n_clips,
|
||||||
|
"experiments": [],
|
||||||
|
}
|
||||||
|
|
||||||
def _write_summary():
|
def _write_summary():
|
||||||
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
||||||
@@ -319,10 +349,9 @@ class SelvaLoraScheduler:
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 5. Run each experiment
|
# 5. Run each experiment
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
trainer = SelvaLoraTrainer()
|
trainer = SelvaLoraTrainer()
|
||||||
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
|
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
|
||||||
all_curve_data = [] # collected for comparison image
|
log_interval = 50 # matches _train_inner
|
||||||
log_interval = 50 # matches _train_inner
|
|
||||||
|
|
||||||
feature_utils_orig = model["feature_utils"]
|
feature_utils_orig = model["feature_utils"]
|
||||||
seq_cfg = model["seq_cfg"]
|
seq_cfg = model["seq_cfg"]
|
||||||
@@ -332,6 +361,11 @@ class SelvaLoraScheduler:
|
|||||||
for exp in spec["experiments"]:
|
for exp in spec["experiments"]:
|
||||||
exp_id = exp["id"]
|
exp_id = exp["id"]
|
||||||
exp_desc = exp.get("description", "")
|
exp_desc = exp.get("description", "")
|
||||||
|
|
||||||
|
if exp_id in completed_ids:
|
||||||
|
print(f"[LoRA Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
|
||||||
|
pbar_outer.update(1)
|
||||||
|
continue
|
||||||
cfg = _merge_config(base_cfg, exp)
|
cfg = _merge_config(base_cfg, exp)
|
||||||
|
|
||||||
# Required training params
|
# Required training params
|
||||||
|
|||||||
Reference in New Issue
Block a user