fix: save adapter and loss curves on cancel, not only on normal completion

Wraps training loop in try/finally so adapter_final.pt and loss PNGs are
always written. On cancellation the adapter is named
adapter_cancelled_stepXXXXX.pt so it can be used with --resume.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-06 01:06:44 +02:00
parent d83632e754
commit fec8eaac95
+95 -79
View File
@@ -507,100 +507,116 @@ class SelvaLoraTrainer:
f"(step {start_step + 1}{steps}, batch_size={batch_size}, " f"(step {start_step + 1}{steps}, batch_size={batch_size}, "
f"timestep_mode={timestep_mode})\n", flush=True) f"timestep_mode={timestep_mode})\n", flush=True)
for step in range(start_step + 1, steps + 1): last_step = start_step
batch = random.choices(dataset, k=batch_size) completed = False
x1_list, clip_list, sync_list, text_list = zip(*batch) try:
for step in range(start_step + 1, steps + 1):
batch = random.choices(dataset, k=batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype) x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype) clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype) sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype) text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
generator.normalize(x1) generator.normalize(x1)
if timestep_mode == "logit_normal": if timestep_mode == "logit_normal":
u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma
t = torch.sigmoid(u) t = torch.sigmoid(u)
else: else:
t = torch.rand(batch_size, device=device, dtype=dtype) t = torch.rand(batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1) x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t) xt = fm.get_conditional_flow(x0, x1, t)
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t) v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
loss = fm.loss(v_pred, x0, x1).mean() / grad_accum loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
loss.backward() loss.backward()
running_loss += loss.item() * grad_accum running_loss += loss.item() * grad_accum
if step % grad_accum == 0: if step % grad_accum == 0:
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
if step % log_interval == 0: if step % log_interval == 0:
avg = running_loss / log_interval avg = running_loss / log_interval
loss_history.append(avg) loss_history.append(avg)
lr_now = scheduler.get_last_lr()[0] lr_now = scheduler.get_last_lr()[0]
print(f"[LoRA Trainer] step {step:5d}/{steps} " print(f"[LoRA Trainer] step {step:5d}/{steps} "
f"loss={avg:.4f} lr={lr_now:.2e} bs={batch_size}", flush=True) f"loss={avg:.4f} lr={lr_now:.2e} bs={batch_size}", flush=True)
running_loss = 0.0 running_loss = 0.0
# Live preview: send updated loss curve to ComfyUI frontend # Live preview: send updated loss curve to ComfyUI frontend
preview_img = _draw_loss_curve(loss_history, log_interval, start_step, preview_img = _draw_loss_curve(loss_history, log_interval, start_step,
smoothed=_smooth_losses(loss_history)) smoothed=_smooth_losses(loss_history))
pbar_train.update_absolute( pbar_train.update_absolute(
step - start_step, remaining, ("JPEG", preview_img, 800) step - start_step, remaining, ("JPEG", preview_img, 800)
) )
if step % save_every == 0 or step == steps: if step % save_every == 0 or step == steps:
ckpt_path = output_dir / f"adapter_step{step:05d}.pt" ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
torch.save({ torch.save({
"state_dict": get_lora_state_dict(generator), "state_dict": get_lora_state_dict(generator),
"optimizer": optimizer.state_dict(), "optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(), "scheduler": scheduler.state_dict(),
"step": step, "step": step,
"meta": meta, "meta": meta,
}, ckpt_path) }, ckpt_path)
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True) print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
# Save a quick eval sample next to the checkpoint # Save a quick eval sample next to the checkpoint
wav, sr = _eval_sample(generator, feature_utils_orig, wav, sr = _eval_sample(generator, feature_utils_orig,
dataset, seq_cfg, device, dtype) dataset, seq_cfg, device, dtype)
if wav is not None: if wav is not None:
wav_path = output_dir / f"sample_step{step:05d}.wav" wav_path = output_dir / f"sample_step{step:05d}.wav"
try: try:
torchaudio.save(str(wav_path), wav, sr) torchaudio.save(str(wav_path), wav, sr)
except RuntimeError: except RuntimeError:
import soundfile as sf import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr) sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True) print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
pbar_train.update(1) last_step = step
pbar_train.update(1)
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible) completed = True
# Increment filename if a previous final already exists (resume case)
final_path = output_dir / "adapter_final.pt"
if final_path.exists():
i = 1
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
i += 1
final_path = output_dir / f"adapter_final_{i:03d}.pt"
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
print(f"\n[LoRA Trainer] Done. Adapter saved to {final_path}", flush=True)
# --- Return patched model --- finally:
# Save adapter and loss curves whether training completed or was cancelled.
# Skip if we never completed a single step (nothing useful to save).
if loss_history:
if completed:
# Normal completion — use adapter_final.pt (increment if exists)
final_path = output_dir / "adapter_final.pt"
if final_path.exists():
i = 1
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
i += 1
final_path = output_dir / f"adapter_final_{i:03d}.pt"
label = "Done"
else:
# Cancelled — include the step number so the file is useful for resume
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
label = f"Cancelled at step {last_step}"
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
smoothed = _smooth_losses(loss_history)
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step,
smoothed=smoothed)
raw_img.save(str(output_dir / "loss_raw.png"))
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
# Reached only on normal completion (exception re-raises past this point)
generator.eval() generator.eval()
generator.to(next(model["generator"].parameters()).device) generator.to(next(model["generator"].parameters()).device)
patched = {**model, "generator": generator} patched = {**model, "generator": generator}
smoothed = _smooth_losses(loss_history)
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step, smoothed=smoothed)
raw_img.save(str(output_dir / "loss_raw.png"))
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
loss_curve = _pil_to_tensor(smoothed_img) loss_curve = _pil_to_tensor(smoothed_img)
return (patched, str(final_path), loss_curve) return (patched, str(final_path), loss_curve)