fix: save adapter and loss curves on cancel, not only on normal completion

Wraps training loop in try/finally so adapter_final.pt and loss PNGs are
always written. On cancellation the adapter is named
adapter_cancelled_stepXXXXX.pt so it can be used with --resume.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-06 01:06:44 +02:00
parent 8338560600
commit c86258d48f
+26 -10
View File
@@ -489,6 +489,9 @@ class SelvaLoraTrainer:
print(f"\n[LoRA Trainer] Training {remaining} steps "
f"(step {start_step + 1}{steps}, batch_size={batch_size})\n", flush=True)
last_step = start_step
completed = False
try:
for step in range(start_step + 1, steps + 1):
batch = random.choices(dataset, k=batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
@@ -553,32 +556,45 @@ class SelvaLoraTrainer:
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
last_step = step
pbar_train.update(1)
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
# Increment filename if a previous final already exists (resume case)
completed = True
finally:
# Save adapter and loss curves whether training completed or was cancelled.
# Skip if we never completed a single step (nothing useful to save).
if loss_history:
if completed:
# Normal completion — use adapter_final.pt (increment if exists)
final_path = output_dir / "adapter_final.pt"
if final_path.exists():
i = 1
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
i += 1
final_path = output_dir / f"adapter_final_{i:03d}.pt"
label = "Done"
else:
# Cancelled — include the step number so the file is useful for resume
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
label = f"Cancelled at step {last_step}"
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
print(f"\n[LoRA Trainer] Done. Adapter saved to {final_path}", flush=True)
# --- Return patched model ---
generator.eval()
generator.to(next(model["generator"].parameters()).device)
patched = {**model, "generator": generator}
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
smoothed = _smooth_losses(loss_history)
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step, smoothed=smoothed)
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step,
smoothed=smoothed)
raw_img.save(str(output_dir / "loss_raw.png"))
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
loss_curve = _pil_to_tensor(smoothed_img)
# Reached only on normal completion (exception re-raises past this point)
generator.eval()
generator.to(next(model["generator"].parameters()).device)
patched = {**model, "generator": generator}
loss_curve = _pil_to_tensor(smoothed_img)
return (patched, str(final_path), loss_curve)