fix: create LoRA params inside torch.enable_grad() to escape inference_mode

torch.enable_grad() re-enables grad tracking but nn.Parameters created while
torch.inference_mode() is active are inference tensors that can't enter autograd
regardless. Splitting into _train_inner() and calling it inside enable_grad()
ensures the deepcopy, apply_lora, and the training loop all run with a clean
autograd context.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 22:36:28 +02:00
parent 505d445eb3
commit 849f31e2a6
+19 -2
View File
@@ -365,6 +365,25 @@ class SelvaLoraTrainer:
raise ValueError("[LoRA Trainer] No clips could be loaded.")
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
# Everything from here runs inside enable_grad: ComfyUI wraps nodes in
# inference_mode, and nn.Parameters created in that context are inference
# tensors that can't enter autograd even with requires_grad=True.
with torch.enable_grad():
return self._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, warmup_steps,
grad_accum, save_every, resume_path, seed,
)
def _train_inner(
self, model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, warmup_steps,
grad_accum, save_every, resume_path, seed,
):
# --- Prepare generator copy with LoRA ---
generator = copy.deepcopy(model["generator"]).to(device, dtype)
@@ -433,8 +452,6 @@ class SelvaLoraTrainer:
print(f"\n[LoRA Trainer] Training {remaining} steps "
f"(step {start_step + 1}{steps})\n", flush=True)
# ComfyUI runs nodes inside torch.no_grad() — re-enable gradients for training
with torch.enable_grad():
for step in range(start_step + 1, steps + 1):
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)