From b430953602f745b937e4f5ff582d8202ec2f5e8f Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 5 Apr 2026 17:11:38 +0200 Subject: [PATCH] feat: live loss curve preview during training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Send updated loss curve to ComfyUI frontend every 50 steps via pbar_train.update_absolute() with a JPEG preview tuple — same mechanism as KSampler's denoising previews. - Fix x-axis step labels for resumed runs (previously always started at 0; now correctly shows start_step + offset). - Split _draw_loss_curve (returns PIL Image) from _pil_to_tensor (converts for ComfyUI IMAGE output). Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_lora_trainer.py | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 2f19249..866fc3b 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -78,16 +78,17 @@ def _load_npz(path: Path) -> dict: # Loss curve rendering # --------------------------------------------------------------------------- -def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor: - """Render a loss curve as a [1, H, W, 3] float32 IMAGE tensor for ComfyUI.""" +def _draw_loss_curve(losses: list[float], log_interval: int, + start_step: int = 0) -> Image.Image: + """Render a loss curve as a PIL Image.""" W, H = 800, 380 - pl, pr, pt, pb = 70, 20, 25, 45 # plot margins + pl, pr, pt, pb = 70, 20, 25, 45 img = Image.new("RGB", (W, H), (255, 255, 255)) draw = ImageDraw.Draw(img) - pw = W - pl - pr # plot area width - ph = H - pt - pb # plot area height + pw = W - pl - pr + ph = H - pt - pb if len(losses) >= 2: lo, hi = min(losses), max(losses) @@ -103,7 +104,7 @@ def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor: draw.text((2, y - 7), f"{val:.4f}", fill=(120, 120, 120)) # Loss line - n = len(losses) + n = len(losses) pts = [] for i, v in enumerate(losses): x = pl + int(i * pw / max(n - 1, 1)) @@ -111,20 +112,26 @@ def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor: pts.append((x, y)) draw.line(pts, fill=(66, 133, 244), width=2) - # x-axis step labels - total_steps = n * log_interval + # x-axis step labels — account for start_step so resumed runs are correct + first_step = start_step + log_interval + last_step = start_step + n * log_interval for i in range(5): x = pl + int(i * pw / 4) - step = int(i * total_steps / 4) + step = int(first_step + i * (last_step - first_step) / 4) draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120)) # Axes - draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1) + draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1) draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1) draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40)) + return img + + +def _pil_to_tensor(img: Image.Image) -> torch.Tensor: + """Convert a PIL Image to a [1, H, W, 3] float32 IMAGE tensor for ComfyUI.""" arr = np.array(img).astype(np.float32) / 255.0 - return torch.from_numpy(arr).unsqueeze(0) # [1, H, W, 3] + return torch.from_numpy(arr).unsqueeze(0) # --------------------------------------------------------------------------- @@ -382,6 +389,12 @@ class SelvaLoraTrainer: f"loss={avg:.4f} lr={lr_now:.2e}", flush=True) running_loss = 0.0 + # Live preview: send updated loss curve to ComfyUI frontend + preview_img = _draw_loss_curve(loss_history, log_interval, start_step) + pbar_train.update_absolute( + step - start_step, remaining, ("JPEG", preview_img, 85) + ) + if step % save_every == 0 or step == steps: ckpt_path = output_dir / f"adapter_step{step:05d}.pt" torch.save({ @@ -406,6 +419,6 @@ class SelvaLoraTrainer: generator.to(next(model["generator"].parameters()).device) patched = {**model, "generator": generator} - loss_curve = _draw_loss_curve(loss_history, log_interval) + loss_curve = _pil_to_tensor(_draw_loss_curve(loss_history, log_interval, start_step)) return (patched, str(final_path), loss_curve)