fix: create LoRA params inside torch.enable_grad() to escape inference_mode
torch.enable_grad() re-enables grad tracking but nn.Parameters created while torch.inference_mode() is active are inference tensors that can't enter autograd regardless. Splitting into _train_inner() and calling it inside enable_grad() ensures the deepcopy, apply_lora, and the training loop all run with a clean autograd context. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -365,6 +365,25 @@ class SelvaLoraTrainer:
|
|||||||
raise ValueError("[LoRA Trainer] No clips could be loaded.")
|
raise ValueError("[LoRA Trainer] No clips could be loaded.")
|
||||||
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
|
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
|
||||||
|
|
||||||
|
# Everything from here runs inside enable_grad: ComfyUI wraps nodes in
|
||||||
|
# inference_mode, and nn.Parameters created in that context are inference
|
||||||
|
# tensors that can't enter autograd even with requires_grad=True.
|
||||||
|
with torch.enable_grad():
|
||||||
|
return self._train_inner(
|
||||||
|
model, dataset, feature_utils_orig, seq_cfg,
|
||||||
|
device, dtype, variant, mode,
|
||||||
|
data_dir, output_dir, steps, rank, lr,
|
||||||
|
alpha_val, target_suffixes, warmup_steps,
|
||||||
|
grad_accum, save_every, resume_path, seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _train_inner(
|
||||||
|
self, model, dataset, feature_utils_orig, seq_cfg,
|
||||||
|
device, dtype, variant, mode,
|
||||||
|
data_dir, output_dir, steps, rank, lr,
|
||||||
|
alpha_val, target_suffixes, warmup_steps,
|
||||||
|
grad_accum, save_every, resume_path, seed,
|
||||||
|
):
|
||||||
# --- Prepare generator copy with LoRA ---
|
# --- Prepare generator copy with LoRA ---
|
||||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||||
|
|
||||||
@@ -433,8 +452,6 @@ class SelvaLoraTrainer:
|
|||||||
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
||||||
f"(step {start_step + 1} → {steps})\n", flush=True)
|
f"(step {start_step + 1} → {steps})\n", flush=True)
|
||||||
|
|
||||||
# ComfyUI runs nodes inside torch.no_grad() — re-enable gradients for training
|
|
||||||
with torch.enable_grad():
|
|
||||||
for step in range(start_step + 1, steps + 1):
|
for step in range(start_step + 1, steps + 1):
|
||||||
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user