fix: save adapter and loss curves on cancel, not only on normal completion
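Wraps the training loop in try/finally so that an adapter checkpoint and the loss PNGs are written on both normal completion and cancellation, as long as loss_history is non-empty. On normal completion the adapter is saved as adapter_final.pt (or adapter_final_NNN.pt if that file already exists); on cancellation it is named adapter_cancelled_stepXXXXX.pt so it can be used with --resume. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>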
This commit is contained in:
+26
-10
@@ -507,6 +507,9 @@ class SelvaLoraTrainer:
|
||||
f"(step {start_step + 1} → {steps}, batch_size={batch_size}, "
|
||||
f"timestep_mode={timestep_mode})\n", flush=True)
|
||||
|
||||
last_step = start_step
|
||||
completed = False
|
||||
try:
|
||||
for step in range(start_step + 1, steps + 1):
|
||||
batch = random.choices(dataset, k=batch_size)
|
||||
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||||
@@ -575,32 +578,45 @@ class SelvaLoraTrainer:
|
||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||||
|
||||
last_step = step
|
||||
pbar_train.update(1)
|
||||
|
||||
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
|
||||
# Increment filename if a previous final already exists (resume case)
|
||||
completed = True
|
||||
|
||||
finally:
|
||||
# Save adapter and loss curves whether training completed or was cancelled.
|
||||
# Skip if we never completed a single step (nothing useful to save).
|
||||
if loss_history:
|
||||
if completed:
|
||||
# Normal completion — use adapter_final.pt (increment if exists)
|
||||
final_path = output_dir / "adapter_final.pt"
|
||||
if final_path.exists():
|
||||
i = 1
|
||||
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
|
||||
i += 1
|
||||
final_path = output_dir / f"adapter_final_{i:03d}.pt"
|
||||
label = "Done"
|
||||
else:
|
||||
# Cancelled — include the step number so the file is useful for resume
|
||||
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
|
||||
label = f"Cancelled at step {last_step}"
|
||||
|
||||
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
|
||||
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||||
print(f"\n[LoRA Trainer] Done. Adapter saved to {final_path}", flush=True)
|
||||
|
||||
# --- Return patched model ---
|
||||
generator.eval()
|
||||
generator.to(next(model["generator"].parameters()).device)
|
||||
patched = {**model, "generator": generator}
|
||||
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
|
||||
|
||||
smoothed = _smooth_losses(loss_history)
|
||||
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
||||
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step, smoothed=smoothed)
|
||||
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step,
|
||||
smoothed=smoothed)
|
||||
raw_img.save(str(output_dir / "loss_raw.png"))
|
||||
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
|
||||
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
|
||||
|
||||
loss_curve = _pil_to_tensor(smoothed_img)
|
||||
# Reached only on normal completion (exception re-raises past this point)
|
||||
generator.eval()
|
||||
generator.to(next(model["generator"].parameters()).device)
|
||||
patched = {**model, "generator": generator}
|
||||
|
||||
loss_curve = _pil_to_tensor(smoothed_img)
|
||||
return (patched, str(final_path), loss_curve)
|
||||
|
||||
Reference in New Issue
Block a user