feat: live loss curve preview during training

- Send updated loss curve to ComfyUI frontend every 50 steps via
  pbar_train.update_absolute() with a JPEG preview tuple — same
  mechanism as KSampler's denoising previews.
- Fix x-axis step labels for resumed runs (previously always started
  at 0; now correctly shows start_step + offset).
- Split _draw_loss_curve (returns PIL Image) from _pil_to_tensor
  (converts for ComfyUI IMAGE output).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 17:11:38 +02:00
parent 57cd3dd4b4
commit b430953602
+23 -10
View File
@@ -78,16 +78,17 @@ def _load_npz(path: Path) -> dict:
# Loss curve rendering # Loss curve rendering
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor: def _draw_loss_curve(losses: list[float], log_interval: int,
"""Render a loss curve as a [1, H, W, 3] float32 IMAGE tensor for ComfyUI.""" start_step: int = 0) -> Image.Image:
"""Render a loss curve as a PIL Image."""
W, H = 800, 380 W, H = 800, 380
pl, pr, pt, pb = 70, 20, 25, 45 # plot margins pl, pr, pt, pb = 70, 20, 25, 45
img = Image.new("RGB", (W, H), (255, 255, 255)) img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img) draw = ImageDraw.Draw(img)
pw = W - pl - pr # plot area width pw = W - pl - pr
ph = H - pt - pb # plot area height ph = H - pt - pb
if len(losses) >= 2: if len(losses) >= 2:
lo, hi = min(losses), max(losses) lo, hi = min(losses), max(losses)
@@ -111,11 +112,12 @@ def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
pts.append((x, y)) pts.append((x, y))
draw.line(pts, fill=(66, 133, 244), width=2) draw.line(pts, fill=(66, 133, 244), width=2)
# x-axis step labels # x-axis step labels — account for start_step so resumed runs are correct
total_steps = n * log_interval first_step = start_step + log_interval
last_step = start_step + n * log_interval
for i in range(5): for i in range(5):
x = pl + int(i * pw / 4) x = pl + int(i * pw / 4)
step = int(i * total_steps / 4) step = int(first_step + i * (last_step - first_step) / 4)
draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120)) draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))
# Axes # Axes
@@ -123,8 +125,13 @@ def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1) draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40)) draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))
return img
def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
"""Convert a PIL Image to a [1, H, W, 3] float32 IMAGE tensor for ComfyUI."""
arr = np.array(img).astype(np.float32) / 255.0 arr = np.array(img).astype(np.float32) / 255.0
return torch.from_numpy(arr).unsqueeze(0) # [1, H, W, 3] return torch.from_numpy(arr).unsqueeze(0)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -382,6 +389,12 @@ class SelvaLoraTrainer:
f"loss={avg:.4f} lr={lr_now:.2e}", flush=True) f"loss={avg:.4f} lr={lr_now:.2e}", flush=True)
running_loss = 0.0 running_loss = 0.0
# Live preview: send updated loss curve to ComfyUI frontend
preview_img = _draw_loss_curve(loss_history, log_interval, start_step)
pbar_train.update_absolute(
step - start_step, remaining, ("JPEG", preview_img, 85)
)
if step % save_every == 0 or step == steps: if step % save_every == 0 or step == steps:
ckpt_path = output_dir / f"adapter_step{step:05d}.pt" ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
torch.save({ torch.save({
@@ -406,6 +419,6 @@ class SelvaLoraTrainer:
generator.to(next(model["generator"].parameters()).device) generator.to(next(model["generator"].parameters()).device)
patched = {**model, "generator": generator} patched = {**model, "generator": generator}
loss_curve = _draw_loss_curve(loss_history, log_interval) loss_curve = _pil_to_tensor(_draw_loss_curve(loss_history, log_interval, start_step))
return (patched, str(final_path), loss_curve) return (patched, str(final_path), loss_curve)