Add a cooperative "skip current experiment" flag to the LoRA scheduler/trainer.

NOTE(review): line breaks were restored from the hunk counts; the leading
indentation of context/added lines is reconstructed -- confirm against the
target files (or apply with --ignore-whitespace) before merging.

diff --git a/nodes/selva_lora_scheduler.py b/nodes/selva_lora_scheduler.py
index 03270e7..0fe2dcf 100644
--- a/nodes/selva_lora_scheduler.py
+++ b/nodes/selva_lora_scheduler.py
@@ -38,6 +38,7 @@ import folder_paths
 from .utils import SELVA_CATEGORY, get_device
 from .selva_lora_trainer import (
     SelvaLoraTrainer,
+    SkipExperiment,
     _prepare_dataset,
     _smooth_losses,
     _pil_to_tensor,
@@ -474,6 +475,20 @@ class SelvaLoraScheduler:
                     "start_step": 0,
                 })
 
+            except SkipExperiment as e:
+                # Skip requested via skip_current.flag: record the experiment as
+                # "skipped" (with the reason and elapsed time) and go on to the next one.
+                duration = time.monotonic() - t_start
+                print(f"[LoRA Scheduler] Experiment '{exp_id}' skipped: {e}", flush=True)
+                exp_record["results"] = {
+                    "status": "skipped",
+                    "error": str(e),
+                    "duration_seconds": round(duration, 1),
+                }
+                _write_summary()
+                pbar_outer.update(1)
+                continue
+
             except Exception as e:
                 duration = time.monotonic() - t_start
                 print(f"[LoRA Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py
index 959fbe0..7cffe47 100644
--- a/nodes/selva_lora_trainer.py
+++ b/nodes/selva_lora_trainer.py
@@ -4,6 +4,10 @@ import random
 import traceback
 from pathlib import Path
 
+
+class SkipExperiment(Exception):
+    """Raised when skip_current.flag is found — signals the scheduler to move to the next experiment."""
+
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -751,6 +755,17 @@ class SelvaLoraTrainer:
                 optimizer.zero_grad()
 
             if step % log_interval == 0:
+                # The flag lives in output_dir's parent directory and is only
+                # checked on log-interval steps.
+                skip_flag = output_dir.parent / "skip_current.flag"
+                try:
+                    # Deleting doubles as the existence test, so a flag that
+                    # vanishes in between cannot surface as FileNotFoundError.
+                    skip_flag.unlink()
+                except FileNotFoundError:
+                    pass  # no flag present: keep training
+                else:
+                    raise SkipExperiment(f"skip_current.flag detected at step {step} — skipping to next experiment")
+
                 avg = running_loss / log_interval
                 loss_history.append(avg)
                 # grad_norm_count can be 0 when grad_accum > log_interval