fix: save adapter and loss curves on cancel, not only on normal completion
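Wraps the training loop in try/finally so that the adapter and the loss PNGs are written on cancellation as well as on normal completion (provided at least one loss value has been logged). On normal completion the adapter is saved as adapter_final.pt; on cancellation it is named adapter_cancelled_stepXXXXX.pt so it can be used with --resume.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>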
This commit is contained in:
+91
-75
@@ -489,96 +489,112 @@ class SelvaLoraTrainer:
|
|||||||
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
||||||
f"(step {start_step + 1} → {steps}, batch_size={batch_size})\n", flush=True)
|
f"(step {start_step + 1} → {steps}, batch_size={batch_size})\n", flush=True)
|
||||||
|
|
||||||
for step in range(start_step + 1, steps + 1):
|
last_step = start_step
|
||||||
batch = random.choices(dataset, k=batch_size)
|
completed = False
|
||||||
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
try:
|
||||||
|
for step in range(start_step + 1, steps + 1):
|
||||||
|
batch = random.choices(dataset, k=batch_size)
|
||||||
|
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||||||
|
|
||||||
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
||||||
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
|
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
|
||||||
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
||||||
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
|
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
|
||||||
|
|
||||||
generator.normalize(x1)
|
generator.normalize(x1)
|
||||||
|
|
||||||
t = torch.rand(batch_size, device=device, dtype=dtype)
|
t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||||
x0 = torch.randn_like(x1)
|
x0 = torch.randn_like(x1)
|
||||||
xt = fm.get_conditional_flow(x0, x1, t)
|
xt = fm.get_conditional_flow(x0, x1, t)
|
||||||
|
|
||||||
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
|
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
|
||||||
loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
|
loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
|
||||||
loss.backward()
|
loss.backward()
|
||||||
running_loss += loss.item() * grad_accum
|
running_loss += loss.item() * grad_accum
|
||||||
|
|
||||||
if step % grad_accum == 0:
|
if step % grad_accum == 0:
|
||||||
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
|
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if step % log_interval == 0:
|
if step % log_interval == 0:
|
||||||
avg = running_loss / log_interval
|
avg = running_loss / log_interval
|
||||||
loss_history.append(avg)
|
loss_history.append(avg)
|
||||||
lr_now = scheduler.get_last_lr()[0]
|
lr_now = scheduler.get_last_lr()[0]
|
||||||
print(f"[LoRA Trainer] step {step:5d}/{steps} "
|
print(f"[LoRA Trainer] step {step:5d}/{steps} "
|
||||||
f"loss={avg:.4f} lr={lr_now:.2e} bs={batch_size}", flush=True)
|
f"loss={avg:.4f} lr={lr_now:.2e} bs={batch_size}", flush=True)
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
|
|
||||||
# Live preview: send updated loss curve to ComfyUI frontend
|
# Live preview: send updated loss curve to ComfyUI frontend
|
||||||
preview_img = _draw_loss_curve(loss_history, log_interval, start_step,
|
preview_img = _draw_loss_curve(loss_history, log_interval, start_step,
|
||||||
smoothed=_smooth_losses(loss_history))
|
smoothed=_smooth_losses(loss_history))
|
||||||
pbar_train.update_absolute(
|
pbar_train.update_absolute(
|
||||||
step - start_step, remaining, ("JPEG", preview_img, 800)
|
step - start_step, remaining, ("JPEG", preview_img, 800)
|
||||||
)
|
)
|
||||||
|
|
||||||
if step % save_every == 0 or step == steps:
|
if step % save_every == 0 or step == steps:
|
||||||
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
||||||
torch.save({
|
torch.save({
|
||||||
"state_dict": get_lora_state_dict(generator),
|
"state_dict": get_lora_state_dict(generator),
|
||||||
"optimizer": optimizer.state_dict(),
|
"optimizer": optimizer.state_dict(),
|
||||||
"scheduler": scheduler.state_dict(),
|
"scheduler": scheduler.state_dict(),
|
||||||
"step": step,
|
"step": step,
|
||||||
"meta": meta,
|
"meta": meta,
|
||||||
}, ckpt_path)
|
}, ckpt_path)
|
||||||
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
|
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
|
||||||
|
|
||||||
# Save a quick eval sample next to the checkpoint
|
# Save a quick eval sample next to the checkpoint
|
||||||
wav, sr = _eval_sample(generator, feature_utils_orig,
|
wav, sr = _eval_sample(generator, feature_utils_orig,
|
||||||
dataset, seq_cfg, device, dtype)
|
dataset, seq_cfg, device, dtype)
|
||||||
if wav is not None:
|
if wav is not None:
|
||||||
wav_path = output_dir / f"sample_step{step:05d}.wav"
|
wav_path = output_dir / f"sample_step{step:05d}.wav"
|
||||||
try:
|
try:
|
||||||
torchaudio.save(str(wav_path), wav, sr)
|
torchaudio.save(str(wav_path), wav, sr)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||||
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||||||
|
|
||||||
pbar_train.update(1)
|
last_step = step
|
||||||
|
pbar_train.update(1)
|
||||||
|
|
||||||
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
|
completed = True
|
||||||
# Increment filename if a previous final already exists (resume case)
|
|
||||||
final_path = output_dir / "adapter_final.pt"
|
|
||||||
if final_path.exists():
|
|
||||||
i = 1
|
|
||||||
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
|
|
||||||
i += 1
|
|
||||||
final_path = output_dir / f"adapter_final_{i:03d}.pt"
|
|
||||||
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
|
|
||||||
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
|
||||||
print(f"\n[LoRA Trainer] Done. Adapter saved to {final_path}", flush=True)
|
|
||||||
|
|
||||||
# --- Return patched model ---
|
finally:
|
||||||
|
# Save adapter and loss curves whether training completed or was cancelled.
|
||||||
|
# Skip if no loss value was ever logged (loss_history is empty), i.e. training stopped before the first log interval — nothing useful to save.
|
||||||
|
if loss_history:
|
||||||
|
if completed:
|
||||||
|
# Normal completion — use adapter_final.pt (increment if exists)
|
||||||
|
final_path = output_dir / "adapter_final.pt"
|
||||||
|
if final_path.exists():
|
||||||
|
i = 1
|
||||||
|
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
|
||||||
|
i += 1
|
||||||
|
final_path = output_dir / f"adapter_final_{i:03d}.pt"
|
||||||
|
label = "Done"
|
||||||
|
else:
|
||||||
|
# Cancelled — include the step number so the file is useful for resume
|
||||||
|
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
|
||||||
|
label = f"Cancelled at step {last_step}"
|
||||||
|
|
||||||
|
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
|
||||||
|
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||||||
|
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
|
||||||
|
|
||||||
|
smoothed = _smooth_losses(loss_history)
|
||||||
|
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
||||||
|
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step,
|
||||||
|
smoothed=smoothed)
|
||||||
|
raw_img.save(str(output_dir / "loss_raw.png"))
|
||||||
|
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
|
||||||
|
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
|
||||||
|
|
||||||
|
# Reached only on normal completion (exception re-raises past this point)
|
||||||
generator.eval()
|
generator.eval()
|
||||||
generator.to(next(model["generator"].parameters()).device)
|
generator.to(next(model["generator"].parameters()).device)
|
||||||
patched = {**model, "generator": generator}
|
patched = {**model, "generator": generator}
|
||||||
|
|
||||||
smoothed = _smooth_losses(loss_history)
|
|
||||||
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
|
||||||
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step, smoothed=smoothed)
|
|
||||||
raw_img.save(str(output_dir / "loss_raw.png"))
|
|
||||||
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
|
|
||||||
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
|
|
||||||
|
|
||||||
loss_curve = _pil_to_tensor(smoothed_img)
|
loss_curve = _pil_to_tensor(smoothed_img)
|
||||||
|
|
||||||
return (patched, str(final_path), loss_curve)
|
return (patched, str(final_path), loss_curve)
|
||||||
|
|||||||
Reference in New Issue
Block a user