fix: create LoRA params inside torch.enable_grad() to escape inference_mode
torch.enable_grad() re-enables grad tracking but nn.Parameters created while torch.inference_mode() is active are inference tensors that can't enter autograd regardless. Splitting into _train_inner() and calling it inside enable_grad() ensures the deepcopy, apply_lora, and the training loop all run with a clean autograd context. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+68
-51
@@ -365,6 +365,25 @@ class SelvaLoraTrainer:
|
|||||||
raise ValueError("[LoRA Trainer] No clips could be loaded.")
|
raise ValueError("[LoRA Trainer] No clips could be loaded.")
|
||||||
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
|
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
|
||||||
|
|
||||||
|
# Everything from here runs inside enable_grad: ComfyUI wraps nodes in
|
||||||
|
# inference_mode, and nn.Parameters created in that context are inference
|
||||||
|
# tensors that can't enter autograd even with requires_grad=True.
|
||||||
|
with torch.enable_grad():
|
||||||
|
return self._train_inner(
|
||||||
|
model, dataset, feature_utils_orig, seq_cfg,
|
||||||
|
device, dtype, variant, mode,
|
||||||
|
data_dir, output_dir, steps, rank, lr,
|
||||||
|
alpha_val, target_suffixes, warmup_steps,
|
||||||
|
grad_accum, save_every, resume_path, seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _train_inner(
|
||||||
|
self, model, dataset, feature_utils_orig, seq_cfg,
|
||||||
|
device, dtype, variant, mode,
|
||||||
|
data_dir, output_dir, steps, rank, lr,
|
||||||
|
alpha_val, target_suffixes, warmup_steps,
|
||||||
|
grad_accum, save_every, resume_path, seed,
|
||||||
|
):
|
||||||
# --- Prepare generator copy with LoRA ---
|
# --- Prepare generator copy with LoRA ---
|
||||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||||
|
|
||||||
@@ -433,67 +452,65 @@ class SelvaLoraTrainer:
|
|||||||
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
||||||
f"(step {start_step + 1} → {steps})\n", flush=True)
|
f"(step {start_step + 1} → {steps})\n", flush=True)
|
||||||
|
|
||||||
# ComfyUI runs nodes inside torch.no_grad() — re-enable gradients for training
|
for step in range(start_step + 1, steps + 1):
|
||||||
with torch.enable_grad():
|
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
||||||
for step in range(start_step + 1, steps + 1):
|
|
||||||
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
|
||||||
|
|
||||||
x1 = x1_cpu.to(device, dtype)
|
x1 = x1_cpu.to(device, dtype)
|
||||||
clip_f = clip_f_cpu.to(device, dtype)
|
clip_f = clip_f_cpu.to(device, dtype)
|
||||||
sync_f = sync_f_cpu.to(device, dtype)
|
sync_f = sync_f_cpu.to(device, dtype)
|
||||||
text_clip = text_clip_cpu.to(device, dtype)
|
text_clip = text_clip_cpu.to(device, dtype)
|
||||||
|
|
||||||
generator.normalize(x1)
|
generator.normalize(x1)
|
||||||
|
|
||||||
t = torch.rand(1, device=device, dtype=dtype)
|
t = torch.rand(1, device=device, dtype=dtype)
|
||||||
x0 = torch.randn_like(x1)
|
x0 = torch.randn_like(x1)
|
||||||
xt = fm.get_conditional_flow(x0, x1, t)
|
xt = fm.get_conditional_flow(x0, x1, t)
|
||||||
|
|
||||||
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
|
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
|
||||||
loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
|
loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
|
||||||
loss.backward()
|
loss.backward()
|
||||||
running_loss += loss.item() * grad_accum
|
running_loss += loss.item() * grad_accum
|
||||||
|
|
||||||
if step % grad_accum == 0:
|
if step % grad_accum == 0:
|
||||||
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
|
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if step % log_interval == 0:
|
if step % log_interval == 0:
|
||||||
avg = running_loss / log_interval
|
avg = running_loss / log_interval
|
||||||
loss_history.append(avg)
|
loss_history.append(avg)
|
||||||
lr_now = scheduler.get_last_lr()[0]
|
lr_now = scheduler.get_last_lr()[0]
|
||||||
print(f"[LoRA Trainer] step {step:5d}/{steps} "
|
print(f"[LoRA Trainer] step {step:5d}/{steps} "
|
||||||
f"loss={avg:.4f} lr={lr_now:.2e}", flush=True)
|
f"loss={avg:.4f} lr={lr_now:.2e}", flush=True)
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
|
|
||||||
# Live preview: send updated loss curve to ComfyUI frontend
|
# Live preview: send updated loss curve to ComfyUI frontend
|
||||||
preview_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
preview_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
||||||
pbar_train.update_absolute(
|
pbar_train.update_absolute(
|
||||||
step - start_step, remaining, ("JPEG", preview_img, 85)
|
step - start_step, remaining, ("JPEG", preview_img, 85)
|
||||||
)
|
)
|
||||||
|
|
||||||
if step % save_every == 0 or step == steps:
|
if step % save_every == 0 or step == steps:
|
||||||
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
||||||
torch.save({
|
torch.save({
|
||||||
"state_dict": get_lora_state_dict(generator),
|
"state_dict": get_lora_state_dict(generator),
|
||||||
"optimizer": optimizer.state_dict(),
|
"optimizer": optimizer.state_dict(),
|
||||||
"scheduler": scheduler.state_dict(),
|
"scheduler": scheduler.state_dict(),
|
||||||
"step": step,
|
"step": step,
|
||||||
"meta": meta,
|
"meta": meta,
|
||||||
}, ckpt_path)
|
}, ckpt_path)
|
||||||
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
|
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
|
||||||
|
|
||||||
# Save a quick eval sample next to the checkpoint
|
# Save a quick eval sample next to the checkpoint
|
||||||
wav, sr = _eval_sample(generator, feature_utils_orig,
|
wav, sr = _eval_sample(generator, feature_utils_orig,
|
||||||
dataset, seq_cfg, device, dtype)
|
dataset, seq_cfg, device, dtype)
|
||||||
if wav is not None:
|
if wav is not None:
|
||||||
wav_path = output_dir / f"sample_step{step:05d}.wav"
|
wav_path = output_dir / f"sample_step{step:05d}.wav"
|
||||||
torchaudio.save(str(wav_path), wav, sr)
|
torchaudio.save(str(wav_path), wav, sr)
|
||||||
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||||||
|
|
||||||
pbar_train.update(1)
|
pbar_train.update(1)
|
||||||
|
|
||||||
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
|
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
|
||||||
final_path = output_dir / "adapter_final.pt"
|
final_path = output_dir / "adapter_final.pt"
|
||||||
|
|||||||
Reference in New Issue
Block a user