feat: skip_current.flag to cancel experiment and move to next
Create the flag file in the sweep output_root to skip the running experiment at the next log interval (every 50 steps): touch /path/to/experiment/skip_current.flag Scheduler marks it as 'skipped' in the summary and continues. Skipped experiments are NOT resumed on restart (unlike failed ones). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -38,6 +38,7 @@ import folder_paths
|
||||
from .utils import SELVA_CATEGORY, get_device
|
||||
from .selva_lora_trainer import (
|
||||
SelvaLoraTrainer,
|
||||
SkipExperiment,
|
||||
_prepare_dataset,
|
||||
_smooth_losses,
|
||||
_pil_to_tensor,
|
||||
@@ -474,6 +475,18 @@ class SelvaLoraScheduler:
|
||||
"start_step": 0,
|
||||
})
|
||||
|
||||
except SkipExperiment as e:
|
||||
duration = time.monotonic() - t_start
|
||||
print(f"[LoRA Scheduler] Experiment '{exp_id}' skipped: {e}", flush=True)
|
||||
exp_record["results"] = {
|
||||
"status": "skipped",
|
||||
"error": str(e),
|
||||
"duration_seconds": round(duration, 1),
|
||||
}
|
||||
_write_summary()
|
||||
pbar_outer.update(1)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
duration = time.monotonic() - t_start
|
||||
print(f"[LoRA Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
|
||||
|
||||
@@ -4,6 +4,10 @@ import random
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SkipExperiment(Exception):
|
||||
"""Raised when skip_current.flag is found — signals the scheduler to move to the next experiment."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -751,6 +755,11 @@ class SelvaLoraTrainer:
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_interval == 0:
|
||||
skip_flag = output_dir.parent / "skip_current.flag"
|
||||
if skip_flag.exists():
|
||||
skip_flag.unlink()
|
||||
raise SkipExperiment(f"skip_current.flag detected at step {step} — skipping to next experiment")
|
||||
|
||||
avg = running_loss / log_interval
|
||||
loss_history.append(avg)
|
||||
# grad_norm_count can be 0 when grad_accum > log_interval
|
||||
|
||||
Reference in New Issue
Block a user