feat: SelVA Skip Experiment node + save partial scalars on skip

- New node: SelVA Skip Experiment — writes skip_current.flag from the UI;
  queue it in a second workflow tab while the scheduler is running
- SkipExperiment now attaches partial loss/grad/spectral data to the
  exception so the scheduler saves all collected scalars in the summary

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 13:10:43 +02:00
parent 264dc49d42
commit 58e1985af2
4 changed files with 69 additions and 4 deletions
+1
View File
@@ -9,6 +9,7 @@ _NODES = {
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"), "SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
"SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"), "SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"),
"SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"), "SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"),
"SelvaSkipExperiment": (".selva_skip_experiment", "SelvaSkipExperiment", "SelVA Skip Experiment"),
} }
for key, (module_path, class_name, display_name) in _NODES.items(): for key, (module_path, class_name, display_name) in _NODES.items():
+8 -1
View File
@@ -478,9 +478,16 @@ class SelvaLoraScheduler:
except SkipExperiment as e: except SkipExperiment as e:
duration = time.monotonic() - t_start duration = time.monotonic() - t_start
print(f"[LoRA Scheduler] Experiment '{exp_id}' skipped: {e}", flush=True) print(f"[LoRA Scheduler] Experiment '{exp_id}' skipped: {e}", flush=True)
partial = getattr(e, "partial", {})
lh = partial.get("loss_history", [])
smoothed = _smooth_losses(lh) if lh else []
exp_record["results"] = { exp_record["results"] = {
"status": "skipped", "status": "skipped",
"error": str(e), "stopped_at_step": partial.get("stopped_at_step"),
"final_loss": round(smoothed[-1], 6) if smoothed else None,
"loss_history": [round(v, 6) for v in lh],
"grad_norm_history": partial.get("grad_norm_history", []),
"spectral_metrics": {str(k): v for k, v in partial.get("spectral_metrics", {}).items()},
"duration_seconds": round(duration, 1), "duration_seconds": round(duration, 1),
} }
_write_summary() _write_summary()
+8 -1
View File
@@ -758,7 +758,14 @@ class SelvaLoraTrainer:
skip_flag = output_dir.parent / "skip_current.flag" skip_flag = output_dir.parent / "skip_current.flag"
if skip_flag.exists(): if skip_flag.exists():
skip_flag.unlink() skip_flag.unlink()
raise SkipExperiment(f"skip_current.flag detected at step {step} — skipping to next experiment") exc = SkipExperiment(f"skip_current.flag detected at step {step} — skipping to next experiment")
exc.partial = {
"loss_history": list(loss_history),
"grad_norm_history": list(grad_norm_history),
"spectral_metrics": dict(spectral_metrics),
"stopped_at_step": step,
}
raise exc
avg = running_loss / log_interval avg = running_loss / log_interval
loss_history.append(avg) loss_history.append(avg)
+50
View File
@@ -0,0 +1,50 @@
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaSkipExperiment:
    """Writes skip_current.flag into a sweep output_root.

    Queue this node while a SelVA LoRA Scheduler sweep is running to skip
    the current experiment and move to the next one. The trainer picks up
    the flag within 50 steps (~a few seconds).
    """

    # Marks the node as a terminal/output node so it runs without
    # anything being connected to its output.
    OUTPUT_NODE = True

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required input: the sweep's output_root."""
        return {
            "required": {
                "output_root": ("STRING", {
                    "default": "",
                    "tooltip": "output_root of the running sweep — same value as in your experiments JSON.",
                }),
            },
        }

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        """Force re-execution every time the node is queued.

        This node exists purely for its side effect (touching the flag
        file). Without this hook, queuing it a second time with the same
        ``output_root`` would be served from the execution cache and no new
        flag would be written. NaN never compares equal to itself, so the
        node is always considered changed.
        """
        return float("nan")

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("flag_path",)
    OUTPUT_TOOLTIPS = ("Path where the flag was written.",)
    FUNCTION = "skip"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Signals the running SelVA LoRA Scheduler to skip the current experiment "
        "and move to the next one. Queue this node while the scheduler is running. "
        "Partial scalars collected so far are saved in the summary."
    )

    def skip(self, output_root: str) -> tuple:
        """Create ``skip_current.flag`` inside ``output_root``.

        Args:
            output_root: Directory of the running sweep. Surrounding
                whitespace is stripped; a relative path is resolved against
                ``folder_paths.get_output_directory()``.

        Returns:
            A one-element tuple holding the flag file path as a string.

        Raises:
            FileNotFoundError: If the resolved ``output_root`` does not exist.
        """
        p = Path(output_root.strip())
        if not p.is_absolute():
            # Relative paths are interpreted relative to the output directory.
            p = Path(folder_paths.get_output_directory()) / p
        if not p.exists():
            raise FileNotFoundError(f"[SelVA Skip] output_root not found: {p}")

        flag = p / "skip_current.flag"
        flag.touch()
        print(f"[SelVA Skip] Flag written: {flag}", flush=True)
        return (str(flag),)