feat: live loss curve preview during training
- Send updated loss curve to ComfyUI frontend every 50 steps via pbar_train.update_absolute() with a JPEG preview tuple — same mechanism as KSampler's denoising previews. - Fix x-axis step labels for resumed runs (previously always started at 0; now correctly shows start_step + offset). - Split _draw_loss_curve (returns PIL Image) from _pil_to_tensor (converts for ComfyUI IMAGE output). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+23
-10
@@ -78,16 +78,17 @@ def _load_npz(path: Path) -> dict:
|
|||||||
# Loss curve rendering
|
# Loss curve rendering
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
|
def _draw_loss_curve(losses: list[float], log_interval: int,
|
||||||
"""Render a loss curve as a [1, H, W, 3] float32 IMAGE tensor for ComfyUI."""
|
start_step: int = 0) -> Image.Image:
|
||||||
|
"""Render a loss curve as a PIL Image."""
|
||||||
W, H = 800, 380
|
W, H = 800, 380
|
||||||
pl, pr, pt, pb = 70, 20, 25, 45 # plot margins
|
pl, pr, pt, pb = 70, 20, 25, 45
|
||||||
|
|
||||||
img = Image.new("RGB", (W, H), (255, 255, 255))
|
img = Image.new("RGB", (W, H), (255, 255, 255))
|
||||||
draw = ImageDraw.Draw(img)
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
pw = W - pl - pr # plot area width
|
pw = W - pl - pr
|
||||||
ph = H - pt - pb # plot area height
|
ph = H - pt - pb
|
||||||
|
|
||||||
if len(losses) >= 2:
|
if len(losses) >= 2:
|
||||||
lo, hi = min(losses), max(losses)
|
lo, hi = min(losses), max(losses)
|
||||||
@@ -111,11 +112,12 @@ def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
|
|||||||
pts.append((x, y))
|
pts.append((x, y))
|
||||||
draw.line(pts, fill=(66, 133, 244), width=2)
|
draw.line(pts, fill=(66, 133, 244), width=2)
|
||||||
|
|
||||||
# x-axis step labels
|
# x-axis step labels — account for start_step so resumed runs are correct
|
||||||
total_steps = n * log_interval
|
first_step = start_step + log_interval
|
||||||
|
last_step = start_step + n * log_interval
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
x = pl + int(i * pw / 4)
|
x = pl + int(i * pw / 4)
|
||||||
step = int(i * total_steps / 4)
|
step = int(first_step + i * (last_step - first_step) / 4)
|
||||||
draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))
|
draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))
|
||||||
|
|
||||||
# Axes
|
# Axes
|
||||||
@@ -123,8 +125,13 @@ def _draw_loss_curve(losses: list[float], log_interval: int) -> torch.Tensor:
|
|||||||
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
|
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
|
||||||
draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))
|
draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
|
||||||
|
"""Convert a PIL Image to a [1, H, W, 3] float32 IMAGE tensor for ComfyUI."""
|
||||||
arr = np.array(img).astype(np.float32) / 255.0
|
arr = np.array(img).astype(np.float32) / 255.0
|
||||||
return torch.from_numpy(arr).unsqueeze(0) # [1, H, W, 3]
|
return torch.from_numpy(arr).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -382,6 +389,12 @@ class SelvaLoraTrainer:
|
|||||||
f"loss={avg:.4f} lr={lr_now:.2e}", flush=True)
|
f"loss={avg:.4f} lr={lr_now:.2e}", flush=True)
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
|
|
||||||
|
# Live preview: send updated loss curve to ComfyUI frontend
|
||||||
|
preview_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
||||||
|
pbar_train.update_absolute(
|
||||||
|
step - start_step, remaining, ("JPEG", preview_img, 85)
|
||||||
|
)
|
||||||
|
|
||||||
if step % save_every == 0 or step == steps:
|
if step % save_every == 0 or step == steps:
|
||||||
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
||||||
torch.save({
|
torch.save({
|
||||||
@@ -406,6 +419,6 @@ class SelvaLoraTrainer:
|
|||||||
generator.to(next(model["generator"].parameters()).device)
|
generator.to(next(model["generator"].parameters()).device)
|
||||||
patched = {**model, "generator": generator}
|
patched = {**model, "generator": generator}
|
||||||
|
|
||||||
loss_curve = _draw_loss_curve(loss_history, log_interval)
|
loss_curve = _pil_to_tensor(_draw_loss_curve(loss_history, log_interval, start_step))
|
||||||
|
|
||||||
return (patched, str(final_path), loss_curve)
|
return (patched, str(final_path), loss_curve)
|
||||||
|
|||||||
Reference in New Issue
Block a user