feat: SelVA Skip Experiment node + save partial scalars on skip
- New node: SelVA Skip Experiment — writes skip_current.flag from UI, queue in a second workflow tab while scheduler is running - SkipExperiment now attaches partial loss/grad/spectral data to the exception so the scheduler saves all collected scalars in the summary Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -9,6 +9,7 @@ _NODES = {
|
||||
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
|
||||
"SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"),
|
||||
"SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"),
|
||||
"SelvaSkipExperiment": (".selva_skip_experiment", "SelvaSkipExperiment", "SelVA Skip Experiment"),
|
||||
}
|
||||
|
||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||
|
||||
@@ -478,10 +478,17 @@ class SelvaLoraScheduler:
|
||||
except SkipExperiment as e:
|
||||
duration = time.monotonic() - t_start
|
||||
print(f"[LoRA Scheduler] Experiment '{exp_id}' skipped: {e}", flush=True)
|
||||
partial = getattr(e, "partial", {})
|
||||
lh = partial.get("loss_history", [])
|
||||
smoothed = _smooth_losses(lh) if lh else []
|
||||
exp_record["results"] = {
|
||||
"status": "skipped",
|
||||
"error": str(e),
|
||||
"duration_seconds": round(duration, 1),
|
||||
"status": "skipped",
|
||||
"stopped_at_step": partial.get("stopped_at_step"),
|
||||
"final_loss": round(smoothed[-1], 6) if smoothed else None,
|
||||
"loss_history": [round(v, 6) for v in lh],
|
||||
"grad_norm_history": partial.get("grad_norm_history", []),
|
||||
"spectral_metrics": {str(k): v for k, v in partial.get("spectral_metrics", {}).items()},
|
||||
"duration_seconds": round(duration, 1),
|
||||
}
|
||||
_write_summary()
|
||||
pbar_outer.update(1)
|
||||
|
||||
@@ -758,7 +758,14 @@ class SelvaLoraTrainer:
|
||||
skip_flag = output_dir.parent / "skip_current.flag"
|
||||
if skip_flag.exists():
|
||||
skip_flag.unlink()
|
||||
raise SkipExperiment(f"skip_current.flag detected at step {step} — skipping to next experiment")
|
||||
exc = SkipExperiment(f"skip_current.flag detected at step {step} — skipping to next experiment")
|
||||
exc.partial = {
|
||||
"loss_history": list(loss_history),
|
||||
"grad_norm_history": list(grad_norm_history),
|
||||
"spectral_metrics": dict(spectral_metrics),
|
||||
"stopped_at_step": step,
|
||||
}
|
||||
raise exc
|
||||
|
||||
avg = running_loss / log_interval
|
||||
loss_history.append(avg)
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from pathlib import Path
|
||||
|
||||
import folder_paths
|
||||
|
||||
from .utils import SELVA_CATEGORY
|
||||
|
||||
|
||||
class SelvaSkipExperiment:
|
||||
"""Writes skip_current.flag into a sweep output_root.
|
||||
|
||||
Queue this node while a SelVA LoRA Scheduler sweep is running to skip
|
||||
the current experiment and move to the next one. The trainer picks up
|
||||
the flag within 50 steps (~a few seconds).
|
||||
"""
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"output_root": ("STRING", {
|
||||
"default": "",
|
||||
"tooltip": "output_root of the running sweep — same value as in your experiments JSON.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("flag_path",)
|
||||
OUTPUT_TOOLTIPS = ("Path where the flag was written.",)
|
||||
FUNCTION = "skip"
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
DESCRIPTION = (
|
||||
"Signals the running SelVA LoRA Scheduler to skip the current experiment "
|
||||
"and move to the next one. Queue this node while the scheduler is running. "
|
||||
"Partial scalars collected so far are saved in the summary."
|
||||
)
|
||||
|
||||
def skip(self, output_root: str):
|
||||
p = Path(output_root.strip())
|
||||
if not p.is_absolute():
|
||||
p = Path(folder_paths.get_output_directory()) / p
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"[SelVA Skip] output_root not found: {p}")
|
||||
|
||||
flag = p / "skip_current.flag"
|
||||
flag.touch()
|
||||
print(f"[SelVA Skip] Flag written: {flag}", flush=True)
|
||||
return (str(flag),)
|
||||
Reference in New Issue
Block a user