From 849f31e2a66084d7b45853bd6bfeb9dfe2698cda Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 5 Apr 2026 22:36:28 +0200 Subject: [PATCH] fix: create LoRA params inside torch.enable_grad() to escape inference_mode torch.enable_grad() re-enables grad tracking but nn.Parameters created while torch.inference_mode() is active are inference tensors that can't enter autograd regardless. Splitting into _train_inner() and calling it inside enable_grad() ensures the deepcopy, apply_lora, and the training loop all run with a clean autograd context. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_lora_trainer.py | 119 ++++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 51 deletions(-) diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 3fb8721..6ebd4b3 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -365,6 +365,25 @@ class SelvaLoraTrainer: raise ValueError("[LoRA Trainer] No clips could be loaded.") print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True) + # Everything from here runs inside enable_grad: ComfyUI wraps nodes in + # inference_mode, and nn.Parameters created in that context are inference + # tensors that can't enter autograd even with requires_grad=True. + with torch.enable_grad(): + return self._train_inner( + model, dataset, feature_utils_orig, seq_cfg, + device, dtype, variant, mode, + data_dir, output_dir, steps, rank, lr, + alpha_val, target_suffixes, warmup_steps, + grad_accum, save_every, resume_path, seed, + ) + + def _train_inner( + self, model, dataset, feature_utils_orig, seq_cfg, + device, dtype, variant, mode, + data_dir, output_dir, steps, rank, lr, + alpha_val, target_suffixes, warmup_steps, + grad_accum, save_every, resume_path, seed, + ): # --- Prepare generator copy with LoRA --- generator = copy.deepcopy(model["generator"]).to(device, dtype) @@ -433,67 +452,65 @@ class SelvaLoraTrainer: print(f"\n[LoRA Trainer] Training {remaining} steps " f"(step {start_step + 1} → {steps})\n", flush=True) - # ComfyUI runs nodes inside torch.no_grad() — re-enable gradients for training - with torch.enable_grad(): - for step in range(start_step + 1, steps + 1): - x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) + for step in range(start_step + 1, steps + 1): + x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) - x1 = x1_cpu.to(device, dtype) - clip_f = clip_f_cpu.to(device, dtype) - sync_f = sync_f_cpu.to(device, dtype) - text_clip = text_clip_cpu.to(device, dtype) + x1 = x1_cpu.to(device, dtype) + clip_f = clip_f_cpu.to(device, dtype) + sync_f = sync_f_cpu.to(device, dtype) + text_clip = text_clip_cpu.to(device, dtype) - generator.normalize(x1) + generator.normalize(x1) - t = torch.rand(1, device=device, dtype=dtype) - x0 = torch.randn_like(x1) - xt = fm.get_conditional_flow(x0, x1, t) + t = torch.rand(1, device=device, dtype=dtype) + x0 = torch.randn_like(x1) + xt = fm.get_conditional_flow(x0, x1, t) - v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t) - loss = fm.loss(v_pred, x0, x1).mean() / grad_accum - loss.backward() - running_loss += loss.item() * grad_accum + v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t) + loss = fm.loss(v_pred, x0, x1).mean() / grad_accum + loss.backward() + running_loss += loss.item() * grad_accum - if step % grad_accum == 0: - torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) - optimizer.step() - scheduler.step() - optimizer.zero_grad() + if step % grad_accum == 0: + torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) + optimizer.step() + scheduler.step() + optimizer.zero_grad() - if step % log_interval == 0: - avg = running_loss / log_interval - loss_history.append(avg) - lr_now = scheduler.get_last_lr()[0] - print(f"[LoRA Trainer] step {step:5d}/{steps} " - f"loss={avg:.4f} lr={lr_now:.2e}", flush=True) - running_loss = 0.0 + if step % log_interval == 0: + avg = running_loss / log_interval + loss_history.append(avg) + lr_now = scheduler.get_last_lr()[0] + print(f"[LoRA Trainer] step {step:5d}/{steps} " + f"loss={avg:.4f} lr={lr_now:.2e}", flush=True) + running_loss = 0.0 - # Live preview: send updated loss curve to ComfyUI frontend - preview_img = _draw_loss_curve(loss_history, log_interval, start_step) - pbar_train.update_absolute( - step - start_step, remaining, ("JPEG", preview_img, 85) - ) + # Live preview: send updated loss curve to ComfyUI frontend + preview_img = _draw_loss_curve(loss_history, log_interval, start_step) + pbar_train.update_absolute( + step - start_step, remaining, ("JPEG", preview_img, 85) + ) - if step % save_every == 0 or step == steps: - ckpt_path = output_dir / f"adapter_step{step:05d}.pt" - torch.save({ - "state_dict": get_lora_state_dict(generator), - "optimizer": optimizer.state_dict(), - "scheduler": scheduler.state_dict(), - "step": step, - "meta": meta, - }, ckpt_path) - print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True) + if step % save_every == 0 or step == steps: + ckpt_path = output_dir / f"adapter_step{step:05d}.pt" + torch.save({ + "state_dict": get_lora_state_dict(generator), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "step": step, + "meta": meta, + }, ckpt_path) + print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True) - # Save a quick eval sample next to the checkpoint - wav, sr = _eval_sample(generator, feature_utils_orig, - dataset, seq_cfg, device, dtype) - if wav is not None: - wav_path = output_dir / f"sample_step{step:05d}.wav" - torchaudio.save(str(wav_path), wav, sr) - print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True) + # Save a quick eval sample next to the checkpoint + wav, sr = _eval_sample(generator, feature_utils_orig, + dataset, seq_cfg, device, dtype) + if wav is not None: + wav_path = output_dir / f"sample_step{step:05d}.wav" + torchaudio.save(str(wav_path), wav, sr) + print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True) - pbar_train.update(1) + pbar_train.update(1) # Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible) final_path = output_dir / "adapter_final.pt"