fix: guard model cleanup in try/finally and fix DiTWrapper comments
- Wrap the training loop in try/finally so _unapply_lora always runs. Without this, an exception mid-training would leave LoRALinear wrappers in the cached DiTWrapper; a subsequent training run would then apply LoRA on top of the existing LoRA, silently doubling the effective rank.
- Fix misleading comment: diffusion.model is a DiTWrapper (not a DiffusionTransformer). The DiffusionTransformer is at diffusion.model.model; _apply_lora reaches it recursively, but the direct attribute is the wrapper.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -95,7 +95,7 @@ class PrismAudioLoRALoader:
|
|||||||
# Merge LoRA weights in-place into the DiT's base linear layers.
|
# Merge LoRA weights in-place into the DiT's base linear layers.
|
||||||
# ComfyUI re-executes the upstream ModelLoader on the next queue run
|
# ComfyUI re-executes the upstream ModelLoader on the next queue run
|
||||||
# when inputs change, providing a fresh base model as needed.
|
# when inputs change, providing a fresh base model as needed.
|
||||||
dit = model["model"].model # DiffusionTransformer
|
dit = model["model"].model # DiTWrapper
|
||||||
|
|
||||||
if strength == 0.0:
|
if strength == 0.0:
|
||||||
print("[PrismAudio] LoRA strength=0.0 — skipping merge, base model unchanged.", flush=True)
|
print("[PrismAudio] LoRA strength=0.0 — skipping merge, base model unchanged.", flush=True)
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ class PrismAudioLoRATrainer:
|
|||||||
diffusion.pretransform.to(device)
|
diffusion.pretransform.to(device)
|
||||||
|
|
||||||
# Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B)
|
# Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B)
|
||||||
dit = diffusion.model # DiffusionTransformer
|
dit = diffusion.model # DiTWrapper
|
||||||
for p in dit.parameters():
|
for p in dit.parameters():
|
||||||
p.requires_grad_(False)
|
p.requires_grad_(False)
|
||||||
|
|
||||||
@@ -205,6 +205,7 @@ class PrismAudioLoRATrainer:
|
|||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(train_steps)
|
pbar = comfy.utils.ProgressBar(train_steps)
|
||||||
|
|
||||||
|
try:
|
||||||
for step in range(1, train_steps + 1):
|
for step in range(1, train_steps + 1):
|
||||||
npz_path, audio_path = random.choice(pairs)
|
npz_path, audio_path = random.choice(pairs)
|
||||||
|
|
||||||
@@ -267,7 +268,10 @@ class PrismAudioLoRATrainer:
|
|||||||
|
|
||||||
print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
|
print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
|
||||||
|
|
||||||
# Restore model to base state (remove LoRA wrappers, restore original linears)
|
finally:
|
||||||
|
# Always restore model to base state — even on exception.
|
||||||
|
# Without this, LoRA wrappers would persist in the cached model and
|
||||||
|
# subsequent training runs would apply LoRA on top of existing LoRA.
|
||||||
dit.eval()
|
dit.eval()
|
||||||
_unapply_lora(dit)
|
_unapply_lora(dit)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user