fix: wrap training loop in torch.enable_grad()

ComfyUI executes all nodes inside torch.no_grad(), which prevents gradient
tracking and makes loss.backward() fail. torch.enable_grad() overrides it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 22:32:00 +02:00
parent 8fade1b0e3
commit 505d445eb3
+51 -49
View File
@@ -433,65 +433,67 @@ class SelvaLoraTrainer:
print(f"\n[LoRA Trainer] Training {remaining} steps " print(f"\n[LoRA Trainer] Training {remaining} steps "
f"(step {start_step + 1}{steps})\n", flush=True) f"(step {start_step + 1}{steps})\n", flush=True)
for step in range(start_step + 1, steps + 1): # ComfyUI runs nodes inside torch.no_grad() — re-enable gradients for training
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) with torch.enable_grad():
for step in range(start_step + 1, steps + 1):
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
x1 = x1_cpu.to(device, dtype) x1 = x1_cpu.to(device, dtype)
clip_f = clip_f_cpu.to(device, dtype) clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype) sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype) text_clip = text_clip_cpu.to(device, dtype)
generator.normalize(x1) generator.normalize(x1)
t = torch.rand(1, device=device, dtype=dtype) t = torch.rand(1, device=device, dtype=dtype)
x0 = torch.randn_like(x1) x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t) xt = fm.get_conditional_flow(x0, x1, t)
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t) v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
loss = fm.loss(v_pred, x0, x1).mean() / grad_accum loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
loss.backward() loss.backward()
running_loss += loss.item() * grad_accum running_loss += loss.item() * grad_accum
if step % grad_accum == 0: if step % grad_accum == 0:
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
if step % log_interval == 0: if step % log_interval == 0:
avg = running_loss / log_interval avg = running_loss / log_interval
loss_history.append(avg) loss_history.append(avg)
lr_now = scheduler.get_last_lr()[0] lr_now = scheduler.get_last_lr()[0]
print(f"[LoRA Trainer] step {step:5d}/{steps} " print(f"[LoRA Trainer] step {step:5d}/{steps} "
f"loss={avg:.4f} lr={lr_now:.2e}", flush=True) f"loss={avg:.4f} lr={lr_now:.2e}", flush=True)
running_loss = 0.0 running_loss = 0.0
# Live preview: send updated loss curve to ComfyUI frontend # Live preview: send updated loss curve to ComfyUI frontend
preview_img = _draw_loss_curve(loss_history, log_interval, start_step) preview_img = _draw_loss_curve(loss_history, log_interval, start_step)
pbar_train.update_absolute( pbar_train.update_absolute(
step - start_step, remaining, ("JPEG", preview_img, 85) step - start_step, remaining, ("JPEG", preview_img, 85)
) )
if step % save_every == 0 or step == steps: if step % save_every == 0 or step == steps:
ckpt_path = output_dir / f"adapter_step{step:05d}.pt" ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
torch.save({ torch.save({
"state_dict": get_lora_state_dict(generator), "state_dict": get_lora_state_dict(generator),
"optimizer": optimizer.state_dict(), "optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(), "scheduler": scheduler.state_dict(),
"step": step, "step": step,
"meta": meta, "meta": meta,
}, ckpt_path) }, ckpt_path)
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True) print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
# Save a quick eval sample next to the checkpoint # Save a quick eval sample next to the checkpoint
wav, sr = _eval_sample(generator, feature_utils_orig, wav, sr = _eval_sample(generator, feature_utils_orig,
dataset, seq_cfg, device, dtype) dataset, seq_cfg, device, dtype)
if wav is not None: if wav is not None:
wav_path = output_dir / f"sample_step{step:05d}.wav" wav_path = output_dir / f"sample_step{step:05d}.wav"
torchaudio.save(str(wav_path), wav, sr) torchaudio.save(str(wav_path), wav, sr)
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True) print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
pbar_train.update(1) pbar_train.update(1)
# Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible) # Save inference adapter (state_dict + meta only — SelvaLoraLoader compatible)
final_path = output_dir / "adapter_final.pt" final_path = output_dir / "adapter_final.pt"