From 3d9221c24864a40cf3ee7ea874b5eeae6d71a6b5 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Mon, 6 Apr 2026 13:11:25 +0200 Subject: [PATCH] fix: three bugs in scheduler and trainer - trainer: raise ValueError early when remaining steps < log_interval (50) instead of UnboundLocalError on smoothed_img/final_path at return - trainer: use None in grad_norm_history instead of silent 0.0 when grad_accum > log_interval and no optimizer step fired in the interval - trainer: include start_step in _train_inner return dict - scheduler: use start_step from result dict for min_loss_step and loss_at_steps (fixes wrong step labels on resumed experiments) Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_lora_scheduler.py | 8 ++++++-- nodes/selva_lora_trainer.py | 34 ++++++++++++++++++++++++---------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/nodes/selva_lora_scheduler.py b/nodes/selva_lora_scheduler.py index 7576148..514f1c9 100644 --- a/nodes/selva_lora_scheduler.py +++ b/nodes/selva_lora_scheduler.py @@ -395,13 +395,17 @@ class SelvaLoraScheduler: duration = time.monotonic() - t_start loss_history = r["loss_history"] grad_norm_history = r.get("grad_norm_history", []) + run_start_step = r.get("start_step", 0) smoothed = _smooth_losses(loss_history) if loss_history else [] # Scalar summary metrics final_loss = round(smoothed[-1], 6) if smoothed else None min_loss = round(min(smoothed), 6) if smoothed else None min_idx = smoothed.index(min(smoothed)) if smoothed else None - min_loss_step = (min_idx + 1) * log_interval if min_idx is not None else None + min_loss_step = ( + run_start_step + (min_idx + 1) * log_interval + if min_idx is not None else None + ) # Stability: std-dev of raw loss over last 25% of steps if loss_history: @@ -418,7 +422,7 @@ class SelvaLoraScheduler: "min_loss_step": min_loss_step, "loss_std_last_quarter": loss_std_last_quarter, "loss_at_steps": _loss_at_steps( - loss_history, log_interval, save_every, 0, steps + loss_history, log_interval, save_every, run_start_step, steps ), "loss_history": [round(v, 6) for v in loss_history], "grad_norm_history": grad_norm_history, diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 7ad619b..d1b9911 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -549,6 +549,12 @@ class SelvaLoraTrainer: log_interval = 50 remaining = steps - start_step + if remaining < log_interval: + raise ValueError( + f"[LoRA Trainer] Only {remaining} steps remaining (steps={steps}, " + f"start_step={start_step}). Need at least {log_interval} steps to " + "record any loss — increase 'steps' or lower the resume checkpoint." + ) pbar_train = comfy.utils.ProgressBar(remaining) loss_history = [] running_loss = 0.0 @@ -621,13 +627,20 @@ class SelvaLoraTrainer: optimizer.zero_grad() if step % log_interval == 0: - avg = running_loss / log_interval - avg_gnorm = running_grad_norm / max(1, grad_norm_count) + avg = running_loss / log_interval loss_history.append(avg) - grad_norm_history.append(round(avg_gnorm, 6)) + # grad_norm_count can be 0 when grad_accum > log_interval + # (no optimizer step fired in this interval yet) + if grad_norm_count > 0: + avg_gnorm = running_grad_norm / grad_norm_count + grad_norm_history.append(round(avg_gnorm, 6)) + gnorm_str = f" grad_norm={avg_gnorm:.4f}" + else: + grad_norm_history.append(None) + gnorm_str = "" lr_now = scheduler.get_last_lr()[0] print(f"[LoRA Trainer] step {step:5d}/{steps} " - f"loss={avg:.4f} grad_norm={avg_gnorm:.4f} " + f"loss={avg:.4f}{gnorm_str} " f"lr={lr_now:.2e} bs={batch_size}", flush=True) running_loss = 0.0 running_grad_norm = 0.0 @@ -705,11 +718,12 @@ class SelvaLoraTrainer: loss_curve = _pil_to_tensor(smoothed_img) return { - "patched_model": patched, - "adapter_path": str(final_path), - "loss_curve": loss_curve, - "loss_history": loss_history, + "patched_model": patched, + "adapter_path": str(final_path), + "loss_curve": loss_curve, + "loss_history": loss_history, "grad_norm_history": grad_norm_history, - "meta": meta, - "completed": True, + "start_step": start_step, + "meta": meta, + "completed": True, }