feat: live loss curve preview during training

- Send updated loss curve to ComfyUI frontend every 50 steps via
  pbar_train.update_absolute() with a JPEG preview tuple — same
  mechanism as KSampler's denoising previews.
- Fix x-axis step labels for resumed runs (previously always started
  at 0; now correctly shows start_step + offset).
- Split rendering from tensor conversion: _draw_loss_curve now returns a
  PIL Image, and the new _pil_to_tensor converts it to a ComfyUI IMAGE
  tensor.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 17:11:38 +02:00
parent 57cd3dd4b4
commit b430953602
+25 -12
View File
@@ -78,16 +78,17 @@ def _load_npz(path: Path) -> dict:
# Loss curve rendering
# ---------------------------------------------------------------------------
def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
"""Render a loss curve as a [1, H, W, 3] float32 IMAGE tensor for ComfyUI."""
def _draw_loss_curve(losses: list[float], log_interval: int,
start_step: int = 0) -> Image.Image:
"""Render a loss curve as a PIL Image."""
W, H = 800, 380
pl, pr, pt, pb = 70, 20, 25, 45 # plot margins
pl, pr, pt, pb = 70, 20, 25, 45
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr # plot area width
ph = H - pt - pb # plot area height
pw = W - pl - pr
ph = H - pt - pb
if len(losses) >= 2:
lo, hi = min(losses), max(losses)
@@ -103,7 +104,7 @@ def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
draw.text((2, y - 7), f"{val:.4f}", fill=(120, 120, 120))
# Loss line
n = len(losses)
n = len(losses)
pts = []
for i, v in enumerate(losses):
x = pl + int(i * pw / max(n - 1, 1))
@@ -111,20 +112,26 @@ def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
pts.append((x, y))
draw.line(pts, fill=(66, 133, 244), width=2)
# x-axis step labels
total_steps = n * log_interval
# x-axis step labels — account for start_step so resumed runs are correct
first_step = start_step + log_interval
last_step = start_step + n * log_interval
for i in range(5):
x = pl + int(i * pw / 4)
step = int(i * total_steps / 4)
step = int(first_step + i * (last_step - first_step) / 4)
draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))
# Axes
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))
return img
def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
"""Convert a PIL Image to a [1, H, W, 3] float32 IMAGE tensor for ComfyUI."""
arr = np.array(img).astype(np.float32) / 255.0
return torch.from_numpy(arr).unsqueeze(0) # [1, H, W, 3]
return torch.from_numpy(arr).unsqueeze(0)
# ---------------------------------------------------------------------------
@@ -382,6 +389,12 @@ class SelvaLoraTrainer:
f"loss={avg:.4f} lr={lr_now:.2e}", flush=True)
running_loss = 0.0
# Live preview: send updated loss curve to ComfyUI frontend
preview_img = _draw_loss_curve(loss_history, log_interval, start_step)
pbar_train.update_absolute(
step - start_step, remaining, ("JPEG", preview_img, 85)
)
if step % save_every == 0 or step == steps:
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
torch.save({
@@ -406,6 +419,6 @@ class SelvaLoraTrainer:
generator.to(next(model["generator"].parameters()).device)
patched = {**model, "generator": generator}
loss_curve = _draw_loss_curve(loss_history, log_interval)
loss_curve = _pil_to_tensor(_draw_loss_curve(loss_history, log_interval, start_step))
return (patched, str(final_path), loss_curve)