fix: guard model cleanup in try/finally and fix DiTWrapper comments
- Wrap training loop in try/finally so _unapply_lora always runs. Without this, an exception mid-training would leave LoRALinear wrappers in the cached DiTWrapper; a subsequent training run would then apply LoRA on top of existing LoRA, silently doubling the effective rank.
- Fix misleading comment: diffusion.model is DiTWrapper (not DiffusionTransformer). DiffusionTransformer is at diffusion.model.model; _apply_lora reaches it recursively but the direct attribute is the wrapper.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+62
-58
@@ -176,7 +176,7 @@ class PrismAudioLoRATrainer:
|
||||
diffusion.pretransform.to(device)
|
||||
|
||||
# Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B)
|
||||
dit = diffusion.model # DiffusionTransformer
|
||||
dit = diffusion.model # DiTWrapper
|
||||
for p in dit.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
@@ -205,76 +205,80 @@ class PrismAudioLoRATrainer:
|
||||
|
||||
pbar = comfy.utils.ProgressBar(train_steps)
|
||||
|
||||
for step in range(1, train_steps + 1):
|
||||
npz_path, audio_path = random.choice(pairs)
|
||||
try:
|
||||
for step in range(1, train_steps + 1):
|
||||
npz_path, audio_path = random.choice(pairs)
|
||||
|
||||
with torch.no_grad():
|
||||
# Encode audio to latent space
|
||||
audio = _load_audio(audio_path, device)
|
||||
x0 = diffusion.pretransform.encode(audio.float()).to(dtype) # [1, 64, L]
|
||||
with torch.no_grad():
|
||||
# Encode audio to latent space
|
||||
audio = _load_audio(audio_path, device)
|
||||
x0 = diffusion.pretransform.encode(audio.float()).to(dtype) # [1, 64, L]
|
||||
|
||||
# Build conditioning from features
|
||||
metadata = (_load_metadata(npz_path, device, dtype),)
|
||||
conditioning = diffusion.conditioner(metadata, device)
|
||||
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
||||
# Build conditioning from features
|
||||
metadata = (_load_metadata(npz_path, device, dtype),)
|
||||
conditioning = diffusion.conditioner(metadata, device)
|
||||
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
||||
|
||||
# Rectified flow: interpolate between data and noise
|
||||
t = torch.rand(x0.shape[0], device=device, dtype=dtype) # [1]
|
||||
noise = torch.randn_like(x0)
|
||||
# t expanded for broadcast: [1] -> [1, 1, 1]
|
||||
t_bcast = t[:, None, None]
|
||||
x_t = (1.0 - t_bcast) * x0 + t_bcast * noise
|
||||
v_target = noise - x0
|
||||
# Rectified flow: interpolate between data and noise
|
||||
t = torch.rand(x0.shape[0], device=device, dtype=dtype) # [1]
|
||||
noise = torch.randn_like(x0)
|
||||
# t expanded for broadcast: [1] -> [1, 1, 1]
|
||||
t_bcast = t[:, None, None]
|
||||
x_t = (1.0 - t_bcast) * x0 + t_bcast * noise
|
||||
v_target = noise - x0
|
||||
|
||||
with torch.amp.autocast(device_type=device.type, dtype=dtype):
|
||||
v_pred = dit(x_t, t,
|
||||
cfg_scale=1.0,
|
||||
cfg_dropout_prob=cfg_dropout_prob,
|
||||
**cond_inputs)
|
||||
with torch.amp.autocast(device_type=device.type, dtype=dtype):
|
||||
v_pred = dit(x_t, t,
|
||||
cfg_scale=1.0,
|
||||
cfg_dropout_prob=cfg_dropout_prob,
|
||||
**cond_inputs)
|
||||
|
||||
loss = F.mse_loss(v_pred.float(), v_target.float())
|
||||
loss = F.mse_loss(v_pred.float(), v_target.float())
|
||||
|
||||
if use_scaler:
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
if use_scaler:
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % 50 == 0:
|
||||
print(f"[PrismAudio] step {step}/{train_steps} loss={loss.item():.6f}", flush=True)
|
||||
if step % 50 == 0:
|
||||
print(f"[PrismAudio] step {step}/{train_steps} loss={loss.item():.6f}", flush=True)
|
||||
|
||||
if step % save_every == 0:
|
||||
ckpt_path = output_path.replace(".safetensors", f"_step{step}.safetensors")
|
||||
save_file(_get_lora_state_dict(dit), ckpt_path)
|
||||
print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True)
|
||||
if step % save_every == 0:
|
||||
ckpt_path = output_path.replace(".safetensors", f"_step{step}.safetensors")
|
||||
save_file(_get_lora_state_dict(dit), ckpt_path)
|
||||
print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True)
|
||||
|
||||
pbar.update(1)
|
||||
pbar.update(1)
|
||||
|
||||
# Save final weights
|
||||
save_file(_get_lora_state_dict(dit), output_path)
|
||||
# Save final weights
|
||||
save_file(_get_lora_state_dict(dit), output_path)
|
||||
|
||||
# Save config alongside weights so the loader knows the structure
|
||||
config_path = output_path.replace(".safetensors", "_config.json")
|
||||
with open(config_path, "w") as f:
|
||||
json.dump({
|
||||
"rank": lora_rank,
|
||||
"alpha": lora_alpha,
|
||||
"target_modules": sorted(target_attrs),
|
||||
}, f, indent=2)
|
||||
# Save config alongside weights so the loader knows the structure
|
||||
config_path = output_path.replace(".safetensors", "_config.json")
|
||||
with open(config_path, "w") as f:
|
||||
json.dump({
|
||||
"rank": lora_rank,
|
||||
"alpha": lora_alpha,
|
||||
"target_modules": sorted(target_attrs),
|
||||
}, f, indent=2)
|
||||
|
||||
print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
|
||||
print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
|
||||
|
||||
# Restore model to base state (remove LoRA wrappers, restore original linears)
|
||||
dit.eval()
|
||||
_unapply_lora(dit)
|
||||
finally:
|
||||
# Always restore model to base state — even on exception.
|
||||
# Without this, LoRA wrappers would persist in the cached model and
|
||||
# subsequent training runs would apply LoRA on top of existing LoRA.
|
||||
dit.eval()
|
||||
_unapply_lora(dit)
|
||||
|
||||
if strategy == "offload_to_cpu":
|
||||
diffusion.model.to(get_offload_device())
|
||||
diffusion.conditioner.to(get_offload_device())
|
||||
diffusion.pretransform.to(get_offload_device())
|
||||
soft_empty_cache()
|
||||
if strategy == "offload_to_cpu":
|
||||
diffusion.model.to(get_offload_device())
|
||||
diffusion.conditioner.to(get_offload_device())
|
||||
diffusion.pretransform.to(get_offload_device())
|
||||
soft_empty_cache()
|
||||
|
||||
return (output_path,)
|
||||
|
||||
Reference in New Issue
Block a user