feat: live loss curve preview during training
- Send the updated loss curve to the ComfyUI frontend every 50 steps via pbar_train.update_absolute() with a JPEG preview tuple — the same mechanism as KSampler's denoising previews.
- Fix x-axis step labels for resumed runs (previously they always started at 0; now they correctly show start_step + offset).
- Split _draw_loss_curve (returns a PIL Image) from _pil_to_tensor (converts it for the ComfyUI IMAGE output).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+23
-10
@@ -78,16 +78,17 @@ def _load_npz(path: Path) -> dict:
|
||||
# Loss curve rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
|
||||
"""Render a loss curve as a [1, H, W, 3] float32 IMAGE tensor for ComfyUI."""
|
||||
def _draw_loss_curve(losses: list[float], log_interval: int,
|
||||
start_step: int = 0) -> Image.Image:
|
||||
"""Render a loss curve as a PIL Image."""
|
||||
W, H = 800, 380
|
||||
pl, pr, pt, pb = 70, 20, 25, 45 # plot margins
|
||||
pl, pr, pt, pb = 70, 20, 25, 45
|
||||
|
||||
img = Image.new("RGB", (W, H), (255, 255, 255))
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
pw = W - pl - pr # plot area width
|
||||
ph = H - pt - pb # plot area height
|
||||
pw = W - pl - pr
|
||||
ph = H - pt - pb
|
||||
|
||||
if len(losses) >= 2:
|
||||
lo, hi = min(losses), max(losses)
|
||||
@@ -111,11 +112,12 @@ def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
|
||||
pts.append((x, y))
|
||||
draw.line(pts, fill=(66, 133, 244), width=2)
|
||||
|
||||
# x-axis step labels
|
||||
total_steps = n * log_interval
|
||||
# x-axis step labels — account for start_step so resumed runs are correct
|
||||
first_step = start_step + log_interval
|
||||
last_step = start_step + n * log_interval
|
||||
for i in range(5):
|
||||
x = pl + int(i * pw / 4)
|
||||
step = int(i * total_steps / 4)
|
||||
step = int(first_step + i * (last_step - first_step) / 4)
|
||||
draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))
|
||||
|
||||
# Axes
|
||||
@@ -123,8 +125,13 @@ def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
|
||||
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
|
||||
draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
|
||||
"""Convert a PIL Image to a [1, H, W, 3] float32 IMAGE tensor for ComfyUI."""
|
||||
arr = np.array(img).astype(np.float32) / 255.0
|
||||
return torch.from_numpy(arr).unsqueeze(0) # [1, H, W, 3]
|
||||
return torch.from_numpy(arr).unsqueeze(0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -382,6 +389,12 @@ class SelvaLoraTrainer:
|
||||
f"loss={avg:.4f} lr={lr_now:.2e}", flush=True)
|
||||
running_loss = 0.0
|
||||
|
||||
# Live preview: send updated loss curve to ComfyUI frontend
|
||||
preview_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
||||
pbar_train.update_absolute(
|
||||
step - start_step, remaining, ("JPEG", preview_img, 85)
|
||||
)
|
||||
|
||||
if step % save_every == 0 or step == steps:
|
||||
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
||||
torch.save({
|
||||
@@ -406,6 +419,6 @@ class SelvaLoraTrainer:
|
||||
generator.to(next(model["generator"].parameters()).device)
|
||||
patched = {**model, "generator": generator}
|
||||
|
||||
loss_curve = _draw_loss_curve(loss_history, log_interval)
|
||||
loss_curve = _pil_to_tensor(_draw_loss_curve(loss_history, log_interval, start_step))
|
||||
|
||||
return (patched, str(final_path), loss_curve)
|
||||
|
||||
Reference in New Issue
Block a user