From 4f40e15db3a642cee0a83d61c6780f522916ab49 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Mar 2026 15:49:04 +0100 Subject: [PATCH] fix: guard model cleanup in try/finally and fix DiTWrapper comments - Wrap training loop in try/finally so _unapply_lora always runs. Without this, an exception mid-training would leave LoRALinear wrappers in the cached DiTWrapper; a subsequent training run would then apply LoRA on top of existing LoRA, silently doubling the effective rank. - Fix misleading comment: diffusion.model is DiTWrapper (not DiffusionTransformer). DiffusionTransformer is at diffusion.model.model; _apply_lora reaches it recursively but the direct attribute is the wrapper. Co-Authored-By: Claude Sonnet 4.6 --- nodes/lora_loader.py | 2 +- nodes/lora_trainer.py | 120 ++++++++++++++++++++++-------------------- 2 files changed, 63 insertions(+), 59 deletions(-) diff --git a/nodes/lora_loader.py b/nodes/lora_loader.py index f88d53e..b4bd831 100644 --- a/nodes/lora_loader.py +++ b/nodes/lora_loader.py @@ -95,7 +95,7 @@ class PrismAudioLoRALoader: # Merge LoRA weights in-place into the DiT's base linear layers. # ComfyUI re-executes the upstream ModelLoader on the next queue run # when inputs change, providing a fresh base model as needed. - dit = model["model"].model # DiffusionTransformer + dit = model["model"].model # DiTWrapper if strength == 0.0: print("[PrismAudio] LoRA strength=0.0 — skipping merge, base model unchanged.", flush=True) diff --git a/nodes/lora_trainer.py b/nodes/lora_trainer.py index 0756776..c34fc33 100644 --- a/nodes/lora_trainer.py +++ b/nodes/lora_trainer.py @@ -176,7 +176,7 @@ class PrismAudioLoRATrainer: diffusion.pretransform.to(device) # Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B) - dit = diffusion.model # DiffusionTransformer + dit = diffusion.model # DiTWrapper for p in dit.parameters(): p.requires_grad_(False) @@ -205,76 +205,80 @@ class PrismAudioLoRATrainer: pbar = comfy.utils.ProgressBar(train_steps) - for step in range(1, train_steps + 1): - npz_path, audio_path = random.choice(pairs) + try: + for step in range(1, train_steps + 1): + npz_path, audio_path = random.choice(pairs) - with torch.no_grad(): - # Encode audio to latent space - audio = _load_audio(audio_path, device) - x0 = diffusion.pretransform.encode(audio.float()).to(dtype) # [1, 64, L] + with torch.no_grad(): + # Encode audio to latent space + audio = _load_audio(audio_path, device) + x0 = diffusion.pretransform.encode(audio.float()).to(dtype) # [1, 64, L] - # Build conditioning from features - metadata = (_load_metadata(npz_path, device, dtype),) - conditioning = diffusion.conditioner(metadata, device) - cond_inputs = diffusion.get_conditioning_inputs(conditioning) + # Build conditioning from features + metadata = (_load_metadata(npz_path, device, dtype),) + conditioning = diffusion.conditioner(metadata, device) + cond_inputs = diffusion.get_conditioning_inputs(conditioning) - # Rectified flow: interpolate between data and noise - t = torch.rand(x0.shape[0], device=device, dtype=dtype) # [1] - noise = torch.randn_like(x0) - # t expanded for broadcast: [1] -> [1, 1, 1] - t_bcast = t[:, None, None] - x_t = (1.0 - t_bcast) * x0 + t_bcast * noise - v_target = noise - x0 + # Rectified flow: interpolate between data and noise + t = torch.rand(x0.shape[0], device=device, dtype=dtype) # [1] + noise = torch.randn_like(x0) + # t expanded for broadcast: [1] -> [1, 1, 1] + t_bcast = t[:, None, None] + x_t = (1.0 - t_bcast) * x0 + t_bcast * noise + v_target = noise - x0 - with torch.amp.autocast(device_type=device.type, dtype=dtype): - v_pred = dit(x_t, t, - cfg_scale=1.0, - cfg_dropout_prob=cfg_dropout_prob, - **cond_inputs) + with torch.amp.autocast(device_type=device.type, dtype=dtype): + v_pred = dit(x_t, t, + cfg_scale=1.0, + cfg_dropout_prob=cfg_dropout_prob, + **cond_inputs) - loss = F.mse_loss(v_pred.float(), v_target.float()) + loss = F.mse_loss(v_pred.float(), v_target.float()) - if use_scaler: - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - loss.backward() - optimizer.step() - optimizer.zero_grad() + if use_scaler: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() + optimizer.zero_grad() - if step % 50 == 0: - print(f"[PrismAudio] step {step}/{train_steps} loss={loss.item():.6f}", flush=True) + if step % 50 == 0: + print(f"[PrismAudio] step {step}/{train_steps} loss={loss.item():.6f}", flush=True) - if step % save_every == 0: - ckpt_path = output_path.replace(".safetensors", f"_step{step}.safetensors") - save_file(_get_lora_state_dict(dit), ckpt_path) - print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True) + if step % save_every == 0: + ckpt_path = output_path.replace(".safetensors", f"_step{step}.safetensors") + save_file(_get_lora_state_dict(dit), ckpt_path) + print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True) - pbar.update(1) + pbar.update(1) - # Save final weights - save_file(_get_lora_state_dict(dit), output_path) + # Save final weights + save_file(_get_lora_state_dict(dit), output_path) - # Save config alongside weights so the loader knows the structure - config_path = output_path.replace(".safetensors", "_config.json") - with open(config_path, "w") as f: - json.dump({ - "rank": lora_rank, - "alpha": lora_alpha, - "target_modules": sorted(target_attrs), - }, f, indent=2) + # Save config alongside weights so the loader knows the structure + config_path = output_path.replace(".safetensors", "_config.json") + with open(config_path, "w") as f: + json.dump({ + "rank": lora_rank, + "alpha": lora_alpha, + "target_modules": sorted(target_attrs), + }, f, indent=2) - print(f"[PrismAudio] LoRA saved: {output_path}", flush=True) + print(f"[PrismAudio] LoRA saved: {output_path}", flush=True) - # Restore model to base state (remove LoRA wrappers, restore original linears) - dit.eval() - _unapply_lora(dit) + finally: + # Always restore model to base state — even on exception. + # Without this, LoRA wrappers would persist in the cached model and + # subsequent training runs would apply LoRA on top of existing LoRA. + dit.eval() + _unapply_lora(dit) - if strategy == "offload_to_cpu": - diffusion.model.to(get_offload_device()) - diffusion.conditioner.to(get_offload_device()) - diffusion.pretransform.to(get_offload_device()) - soft_empty_cache() + if strategy == "offload_to_cpu": + diffusion.model.to(get_offload_device()) + diffusion.conditioner.to(get_offload_device()) + diffusion.pretransform.to(get_offload_device()) + soft_empty_cache() return (output_path,)