diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 58fe2cf..273bd4a 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -507,100 +507,116 @@ class SelvaLoraTrainer: f"(step {start_step + 1} → {steps}, batch_size={batch_size}, " f"timestep_mode={timestep_mode})\n", flush=True) - for step in range(start_step + 1, steps + 1): - batch = random.choices(dataset, k=batch_size) - x1_list, clip_list, sync_list, text_list = zip(*batch) + last_step = start_step + completed = False + try: + for step in range(start_step + 1, steps + 1): + batch = random.choices(dataset, k=batch_size) + x1_list, clip_list, sync_list, text_list = zip(*batch) - x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype) - clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype) - sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype) - text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype) + x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype) + clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype) + sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype) + text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype) - generator.normalize(x1) + generator.normalize(x1) - if timestep_mode == "logit_normal": - u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma - t = torch.sigmoid(u) - else: - t = torch.rand(batch_size, device=device, dtype=dtype) - x0 = torch.randn_like(x1) - xt = fm.get_conditional_flow(x0, x1, t) + if timestep_mode == "logit_normal": + u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma + t = torch.sigmoid(u) + else: + t = torch.rand(batch_size, device=device, dtype=dtype) + x0 = torch.randn_like(x1) + xt = fm.get_conditional_flow(x0, x1, t) - v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t) - loss = fm.loss(v_pred, x0, x1).mean() / grad_accum - loss.backward() - running_loss += loss.item() * grad_accum + v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t) + loss = fm.loss(v_pred, x0, x1).mean() / grad_accum + loss.backward() + running_loss += loss.item() * grad_accum - if step % grad_accum == 0: - torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) - optimizer.step() - scheduler.step() - optimizer.zero_grad() + if step % grad_accum == 0: + torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) + optimizer.step() + scheduler.step() + optimizer.zero_grad() - if step % log_interval == 0: - avg = running_loss / log_interval - loss_history.append(avg) - lr_now = scheduler.get_last_lr()[0] - print(f"[LoRA Trainer] step {step:5d}/{steps} " - f"loss={avg:.4f} lr={lr_now:.2e} bs={batch_size}", flush=True) - running_loss = 0.0 + if step % log_interval == 0: + avg = running_loss / log_interval + loss_history.append(avg) + lr_now = scheduler.get_last_lr()[0] + print(f"[LoRA Trainer] step {step:5d}/{steps} " + f"loss={avg:.4f} lr={lr_now:.2e} bs={batch_size}", flush=True) + running_loss = 0.0 - # Live preview: send updated loss curve to ComfyUI frontend - preview_img = _draw_loss_curve(loss_history, log_interval, start_step, - smoothed=_smooth_losses(loss_history)) - pbar_train.update_absolute( - step - start_step, remaining, ("JPEG", preview_img, 800) - ) + # Live preview: send updated loss curve to ComfyUI frontend + preview_img = _draw_loss_curve(loss_history, log_interval, start_step, + smoothed=_smooth_losses(loss_history)) + pbar_train.update_absolute( + step - start_step, remaining, ("JPEG", preview_img, 800) + ) - if step % save_every == 0 or step == steps: - ckpt_path = output_dir / f"adapter_step{step:05d}.pt" - torch.save({ - "state_dict": get_lora_state_dict(generator), - "optimizer": optimizer.state_dict(), - "scheduler": scheduler.state_dict(), - "step": step, - "meta": meta, - }, ckpt_path) - print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True) + if step % save_every == 0 or step == steps: + ckpt_path = output_dir / f"adapter_step{step:05d}.pt" + torch.save({ + "state_dict": get_lora_state_dict(generator), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "step": step, + "meta": meta, + }, ckpt_path) + print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True) - # Save a quick eval sample next to the checkpoint - wav, sr = _eval_sample(generator, feature_utils_orig, - dataset, seq_cfg, device, dtype) - if wav is not None: - wav_path = output_dir / f"sample_step{step:05d}.wav" - try: - torchaudio.save(str(wav_path), wav, sr) - except RuntimeError: - import soundfile as sf - sf.write(str(wav_path), wav.squeeze(0).numpy(), sr) - print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True) + # Save a quick eval sample next to the checkpoint + wav, sr = _eval_sample(generator, feature_utils_orig, + dataset, seq_cfg, device, dtype) + if wav is not None: + wav_path = output_dir / f"sample_step{step:05d}.wav" + try: + torchaudio.save(str(wav_path), wav, sr) + except RuntimeError: + import soundfile as sf + sf.write(str(wav_path), wav.squeeze(0).numpy(), sr) + print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True) - pbar_train.update(1) + last_step = step + pbar_train.update(1) - # Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible) - # Increment filename if a previous final already exists (resume case) - final_path = output_dir / "adapter_final.pt" - if final_path.exists(): - i = 1 - while (output_dir / f"adapter_final_{i:03d}.pt").exists(): - i += 1 - final_path = output_dir / f"adapter_final_{i:03d}.pt" - torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path) - (output_dir / "meta.json").write_text(json.dumps(meta, indent=2)) - print(f"\n[LoRA Trainer] Done. Adapter saved to {final_path}", flush=True) + completed = True - # --- Return patched model --- + finally: + # Save adapter and loss curves whether training completed or was cancelled. + # Skip if we never completed a single step (nothing useful to save). + if loss_history: + if completed: + # Normal completion — use adapter_final.pt (increment if exists) + final_path = output_dir / "adapter_final.pt" + if final_path.exists(): + i = 1 + while (output_dir / f"adapter_final_{i:03d}.pt").exists(): + i += 1 + final_path = output_dir / f"adapter_final_{i:03d}.pt" + label = "Done" + else: + # Cancelled — include the step number so the file is useful for resume + final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt" + label = f"Cancelled at step {last_step}" + + torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path) + (output_dir / "meta.json").write_text(json.dumps(meta, indent=2)) + print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True) + + smoothed = _smooth_losses(loss_history) + raw_img = _draw_loss_curve(loss_history, log_interval, start_step) + smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step, + smoothed=smoothed) + raw_img.save(str(output_dir / "loss_raw.png")) + smoothed_img.save(str(output_dir / "loss_smoothed.png")) + print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True) + + # Reached only on normal completion (exception re-raises past this point) generator.eval() generator.to(next(model["generator"].parameters()).device) patched = {**model, "generator": generator} - smoothed = _smooth_losses(loss_history) - raw_img = _draw_loss_curve(loss_history, log_interval, start_step) - smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step, smoothed=smoothed) - raw_img.save(str(output_dir / "loss_raw.png")) - smoothed_img.save(str(output_dir / "loss_smoothed.png")) - print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True) - loss_curve = _pil_to_tensor(smoothed_img) - return (patched, str(final_path), loss_curve)