fix: save adapter and loss curves on cancel, not only on normal completion
Wraps training loop in try/finally so adapter_final.pt and loss PNGs are always written. On cancellation the adapter is named adapter_cancelled_stepXXXXX.pt so it can be used with --resume. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+26
-10
@@ -507,6 +507,9 @@ class SelvaLoraTrainer:
|
|||||||
f"(step {start_step + 1} → {steps}, batch_size={batch_size}, "
|
f"(step {start_step + 1} → {steps}, batch_size={batch_size}, "
|
||||||
f"timestep_mode={timestep_mode})\n", flush=True)
|
f"timestep_mode={timestep_mode})\n", flush=True)
|
||||||
|
|
||||||
|
last_step = start_step
|
||||||
|
completed = False
|
||||||
|
try:
|
||||||
for step in range(start_step + 1, steps + 1):
|
for step in range(start_step + 1, steps + 1):
|
||||||
batch = random.choices(dataset, k=batch_size)
|
batch = random.choices(dataset, k=batch_size)
|
||||||
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||||||
@@ -575,32 +578,45 @@ class SelvaLoraTrainer:
|
|||||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||||
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||||||
|
|
||||||
|
last_step = step
|
||||||
pbar_train.update(1)
|
pbar_train.update(1)
|
||||||
|
|
||||||
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
|
completed = True
|
||||||
# Increment filename if a previous final already exists (resume case)
|
|
||||||
|
finally:
|
||||||
|
# Save adapter and loss curves whether training completed or was cancelled.
|
||||||
|
# Skip if we never completed a single step (nothing useful to save).
|
||||||
|
if loss_history:
|
||||||
|
if completed:
|
||||||
|
# Normal completion — use adapter_final.pt (increment if exists)
|
||||||
final_path = output_dir / "adapter_final.pt"
|
final_path = output_dir / "adapter_final.pt"
|
||||||
if final_path.exists():
|
if final_path.exists():
|
||||||
i = 1
|
i = 1
|
||||||
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
|
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
|
||||||
i += 1
|
i += 1
|
||||||
final_path = output_dir / f"adapter_final_{i:03d}.pt"
|
final_path = output_dir / f"adapter_final_{i:03d}.pt"
|
||||||
|
label = "Done"
|
||||||
|
else:
|
||||||
|
# Cancelled — include the step number so the file is useful for resume
|
||||||
|
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
|
||||||
|
label = f"Cancelled at step {last_step}"
|
||||||
|
|
||||||
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
|
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
|
||||||
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||||||
print(f"\n[LoRA Trainer] Done. Adapter saved to {final_path}", flush=True)
|
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
|
||||||
|
|
||||||
# --- Return patched model ---
|
|
||||||
generator.eval()
|
|
||||||
generator.to(next(model["generator"].parameters()).device)
|
|
||||||
patched = {**model, "generator": generator}
|
|
||||||
|
|
||||||
smoothed = _smooth_losses(loss_history)
|
smoothed = _smooth_losses(loss_history)
|
||||||
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
||||||
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step, smoothed=smoothed)
|
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step,
|
||||||
|
smoothed=smoothed)
|
||||||
raw_img.save(str(output_dir / "loss_raw.png"))
|
raw_img.save(str(output_dir / "loss_raw.png"))
|
||||||
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
|
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
|
||||||
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
|
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
|
||||||
|
|
||||||
loss_curve = _pil_to_tensor(smoothed_img)
|
# Reached only on normal completion (exception re-raises past this point)
|
||||||
|
generator.eval()
|
||||||
|
generator.to(next(model["generator"].parameters()).device)
|
||||||
|
patched = {**model, "generator": generator}
|
||||||
|
|
||||||
|
loss_curve = _pil_to_tensor(smoothed_img)
|
||||||
return (patched, str(final_path), loss_curve)
|
return (patched, str(final_path), loss_curve)
|
||||||
|
|||||||
Reference in New Issue
Block a user