Files
ComfyUI-SelVA/nodes/lora_trainer.py
Ethanfel 4f40e15db3 fix: guard model cleanup in try/finally and fix DiTWrapper comments
- Wrap training loop in try/finally so _unapply_lora always runs.
  Without this, an exception mid-training would leave LoRALinear wrappers
  in the cached DiTWrapper; a subsequent training run would then apply LoRA
  on top of existing LoRA, silently doubling the effective rank.
- Fix misleading comment: diffusion.model is DiTWrapper (not DiffusionTransformer).
  DiffusionTransformer is at diffusion.model.model; _apply_lora reaches it
  recursively but the direct attribute is the wrapper.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-28 15:49:04 +01:00

285 lines
12 KiB
Python

import os
import math
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.utils
from .utils import (
PRISMAUDIO_CATEGORY, SAMPLE_RATE,
get_device, get_offload_device, soft_empty_cache,
)
# ---------------------------------------------------------------------------
# LoRA primitives
# ---------------------------------------------------------------------------
class LoRALinear(nn.Module):
    """Add a trainable low-rank update on top of an existing ``nn.Linear``.

    Output is ``linear(x) + lora_B(lora_A(x)) * (alpha / rank)``. The wrapped
    layer is intended to be frozen by the caller; this class does not change
    its ``requires_grad`` flags. Because ``lora_B`` is zero-initialised, a
    freshly built wrapper returns exactly what the wrapped layer returns.
    """

    def __init__(self, linear: nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.linear = linear
        self.scale = alpha / rank
        # Down-projection (in_features -> rank) then up-projection (rank -> out_features).
        self.lora_A = nn.Linear(linear.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, linear.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        base = self.linear(x)
        delta = self.lora_B(self.lora_A(x))
        return base + delta * self.scale
# Preset name -> set of child attribute names whose nn.Linear gets a LoRA adapter.
_TARGET_MODULE_PRESETS = dict(
    # Attention projections only.
    attn_only={"to_q", "to_kv", "to_qkv", "to_out"},
    # Attention projections plus "proj".
    attn_ffn={"to_q", "to_kv", "to_qkv", "to_out", "proj"},
    # Everything above plus "project_in" / "project_out".
    full={"to_q", "to_kv", "to_qkv", "to_out", "proj", "project_in", "project_out"},
)
def _apply_lora(module: nn.Module, target_attrs: set, rank: int, alpha: float):
    """Wrap matching ``nn.Linear`` children of *module* in ``LoRALinear``, in place.

    A child is wrapped when it is an ``nn.Linear`` and its attribute name is
    in *target_attrs*; matching is by attribute name only. Every other child
    is searched recursively, depth-first, in ``named_children`` order. Wrapped
    children are not descended into. Returns ``None``.
    """
    for child_name, sub in list(module.named_children()):
        should_wrap = isinstance(sub, nn.Linear) and child_name in target_attrs
        if not should_wrap:
            _apply_lora(sub, target_attrs, rank, alpha)
            continue
        setattr(module, child_name, LoRALinear(sub, rank, alpha))
def _unapply_lora(module: nn.Module):
    """Swap every ``LoRALinear`` under *module* back to the Linear it wraps.

    The adapter weights are discarded, not merged into the base weight, so
    each restored layer computes what it computed before wrapping. The
    restored layer's ``weight`` is set to ``requires_grad=False``; its bias
    flag is left as-is. Non-``LoRALinear`` children are searched recursively.
    """
    for attr, sub in list(module.named_children()):
        if not isinstance(sub, LoRALinear):
            _unapply_lora(sub)
            continue
        original = sub.linear
        original.weight.requires_grad_(False)
        setattr(module, attr, original)
def _get_lora_state_dict(module: nn.Module) -> dict:
    """Collect the LoRA adapter tensors from *module*'s state dict.

    Keeps every entry whose key contains the substring ``"lora_A"`` or
    ``"lora_B"``, in state-dict order; all other entries are dropped.
    """
    markers = ("lora_A", "lora_B")
    selected = {}
    for key, tensor in module.state_dict().items():
        if any(marker in key for marker in markers):
            selected[key] = tensor
    return selected
# ---------------------------------------------------------------------------
# Dataset helpers
# ---------------------------------------------------------------------------
# Audio extensions tried, in this priority order, for each .npz stem.
_AUDIO_EXTS = (".wav", ".flac", ".mp3")


def _scan_dataset(dataset_dir: str):
    """Return sorted ``(npz_path, audio_path)`` pairs matched by filename stem.

    For each regular file ``<stem>.npz`` directly inside *dataset_dir*, the
    first existing regular file ``<stem><ext>`` (``ext`` taken from
    ``_AUDIO_EXTS`` in order) is paired with it. ``.npz`` files without a
    matching audio file are skipped.

    Returns an empty list when *dataset_dir* is empty or not an existing
    directory. The caller can then raise its own descriptive error instead of
    surfacing a bare FileNotFoundError from ``os.listdir``.
    """
    if not os.path.isdir(dataset_dir):
        return []

    pairs = []
    for fname in os.listdir(dataset_dir):
        if not fname.endswith(".npz"):
            continue
        npz_path = os.path.join(dataset_dir, fname)
        if not os.path.isfile(npz_path):
            continue  # e.g. a sub-directory that happens to be named *.npz
        stem = npz_path[:-4]
        for ext in _AUDIO_EXTS:
            audio_path = stem + ext
            if os.path.isfile(audio_path):
                pairs.append((npz_path, audio_path))
                break
    return sorted(pairs)
def _load_audio(audio_path: str, device: torch.device) -> torch.Tensor:
    """Load an audio file as a ``[1, 2, samples]`` tensor at ``SAMPLE_RATE``.

    The file is resampled when its native rate differs from ``SAMPLE_RATE``.
    Mono input is duplicated to two channels, using an ``expand`` view rather
    than a copy. Input with more than two channels keeps only the first two.
    The result gets a leading batch dimension and is moved to *device*.
    """
    import torchaudio

    wav, native_sr = torchaudio.load(audio_path)
    if native_sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, native_sr, SAMPLE_RATE)

    n_channels = wav.shape[0]
    if n_channels == 1:
        wav = wav.expand(2, -1)
    elif n_channels > 2:
        wav = wav[:2]

    batched = wav.unsqueeze(0)  # [1, 2, samples]
    return batched.to(device)
def _load_metadata(npz_path: str, device: torch.device, dtype: torch.dtype) -> dict:
    """Load one ``.npz`` feature file into a conditioner metadata dict.

    Reads the ``video_features``, ``text_features`` and ``sync_features``
    arrays. Each is converted to a tensor of *dtype* on *device*. The
    ``video_exist`` entry is a 0-d bool tensor, created with ``torch.tensor``,
    that is True when any video feature is non-zero.

    The archive is opened in a ``with`` block so its file handle is closed
    deterministically, including when a key is missing.

    Raises:
        KeyError: if any of the three feature arrays is absent from the file.
    """
    import numpy as np

    keys = ("video_features", "text_features", "sync_features")
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only point this at feature files you
    # produced yourself. It is kept enabled for compatibility with existing
    # datasets.
    with np.load(npz_path, allow_pickle=True) as data:
        missing = [k for k in keys if k not in data.files]
        if missing:
            raise KeyError(f"[PrismAudio] {npz_path} is missing feature array(s): {', '.join(missing)}")
        feats = {k: torch.from_numpy(data[k]).float().to(device, dtype=dtype) for k in keys}

    has_video = bool(feats["video_features"].abs().sum() > 0)
    return {
        "video_features": feats["video_features"],
        "text_features": feats["text_features"],
        "sync_features": feats["sync_features"],
        "video_exist": torch.tensor(has_video),
    }
# ---------------------------------------------------------------------------
# Trainer node
# ---------------------------------------------------------------------------
class PrismAudioLoRATrainer:
    """ComfyUI node: train a LoRA adapter for the PrismAudio DiT.

    Samples (.npz features, audio) pairs from ``dataset_dir`` and trains LoRA
    adapters on the selected DiT linear layers with a rectified-flow velocity
    loss. The adapters are saved as ``.safetensors`` together with a
    ``_config.json``, and the weights path is returned. The cached model is
    always restored to its LoRA-free state afterwards, including when setup
    or training raises.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "dataset_dir": ("STRING", {"default": "", "tooltip": "Directory containing paired .npz feature files and .wav/.flac audio files (matched by filename stem)"}),
                "output_path": ("STRING", {"default": "", "tooltip": "Save path for .safetensors weights. Empty = models/prismaudio/lora/"}),
                "lora_rank": ("INT", {"default": 64, "min": 1, "max": 512}),
                "lora_alpha": ("FLOAT", {"default": 64.0, "min": 1.0, "max": 1024.0}),
                "target_modules": (["attn_ffn", "attn_only", "full"], {"tooltip": "attn_only: Q/K/V/out only. attn_ffn: + FFN input (recommended). full: + transformer I/O projections"}),
                "learning_rate": ("FLOAT", {"default": 1e-4, "min": 1e-7, "max": 1e-2, "step": 1e-6}),
                "train_steps": ("INT", {"default": 1000, "min": 1, "max": 100000}),
                "cfg_dropout_prob": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 0.5, "step": 0.01, "tooltip": "Probability of dropping conditioning per step — preserves CFG ability at inference"}),
                "save_every": ("INT", {"default": 500, "min": 1, "max": 100000, "tooltip": "Save a checkpoint every N steps (in addition to final save)"}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("lora_path",)
    FUNCTION = "train"
    CATEGORY = PRISMAUDIO_CATEGORY

    @staticmethod
    def _resolve_output_path(output_path, lora_rank):
        """Return the final ``.safetensors`` path, creating its directory.

        An empty *output_path* maps to
        ``<models_dir>/prismaudio/lora/prismaudio_lora_r<rank>.safetensors``.
        A path without a ``.safetensors`` extension gets one appended.
        Checkpoint and config paths are derived from this path, so without
        the extension the config JSON would overwrite the weights file.
        """
        if not output_path:
            import folder_paths
            out_dir = os.path.join(folder_paths.models_dir, "prismaudio", "lora")
            os.makedirs(out_dir, exist_ok=True)
            return os.path.join(out_dir, f"prismaudio_lora_r{lora_rank}.safetensors")

        if os.path.splitext(output_path)[1].lower() != ".safetensors":
            output_path += ".safetensors"
        parent = os.path.dirname(output_path)
        if parent:
            # Fail (or succeed) now, not after the whole training run when
            # save_file first tries to open the path.
            os.makedirs(parent, exist_ok=True)
        return output_path

    @staticmethod
    def _lora_weights_for_saving(dit, dtype):
        """Return the LoRA tensors of *dit* cast to *dtype*, ready for ``save_file``.

        The LoRA weights may be trained in fp32 (see ``train``). Casting here
        keeps the on-disk dtype equal to the model dtype on every path.
        """
        return {k: v.to(dtype).contiguous() for k, v in _get_lora_state_dict(dit).items()}

    def train(self, model, dataset_dir, output_path, lora_rank, lora_alpha,
              target_modules, learning_rate, train_steps, cfg_dropout_prob, save_every, seed):
        """Train a LoRA adapter on ``dataset_dir`` and return ``(lora_path,)``.

        ``model`` is the PRISMAUDIO_MODEL dict. Its ``"dtype"``, ``"model"``
        (the diffusion object providing ``model``, ``conditioner``,
        ``pretransform`` and ``get_conditioning_inputs``) and ``"strategy"``
        entries are read.

        Besides the final weights, this writes a
        ``<name>_step<N>.safetensors`` checkpoint every ``save_every`` steps
        and a ``<name>_config.json`` recording rank, alpha and target module
        names.

        Raises:
            RuntimeError: if no (.npz + audio) pairs are found in ``dataset_dir``.
        """
        from safetensors.torch import save_file

        device = get_device()
        dtype = model["dtype"]
        diffusion = model["model"]
        strategy = model["strategy"]

        torch.manual_seed(seed)
        random.seed(seed)

        # Scan dataset
        pairs = _scan_dataset(dataset_dir)
        if not pairs:
            raise RuntimeError(f"[PrismAudio] No (.npz + audio) pairs found in: {dataset_dir}")
        print(f"[PrismAudio] LoRA training — {len(pairs)} sample(s), {train_steps} steps", flush=True)

        output_path = self._resolve_output_path(output_path, lora_rank)
        # Checkpoints and the config file sit next to the final weights file.
        out_root = os.path.splitext(output_path)[0]
        target_attrs = _TARGET_MODULE_PRESETS[target_modules]

        # fp16 needs GradScaler to prevent gradient underflow. GradScaler,
        # however, refuses to unscale fp16 gradients ("Attempting to unscale
        # FP16 gradients"), so in that case the trainable LoRA weights are kept
        # in fp32 and autocast runs the forward pass in fp16. GradScaler is a
        # CUDA facility; on other devices no scaler is used.
        use_scaler = dtype == torch.float16 and device.type == "cuda"
        lora_dtype = torch.float32 if use_scaler else dtype

        dit = diffusion.model  # DiTWrapper
        try:
            # Every mutation of the cached model happens inside this try, so
            # the finally below always undoes it, even if setup itself fails.
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            diffusion.pretransform.to(device)

            # Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B)
            for p in dit.parameters():
                p.requires_grad_(False)
            _apply_lora(dit, target_attrs, lora_rank, lora_alpha)

            # Move LoRA params to the device in the training dtype chosen above
            for m in dit.modules():
                if isinstance(m, LoRALinear):
                    m.lora_A.to(device=device, dtype=lora_dtype)
                    m.lora_B.to(device=device, dtype=lora_dtype)

            trainable = [p for p in dit.parameters() if p.requires_grad]
            n_params = sum(p.numel() for p in trainable)
            print(f"[PrismAudio] LoRA trainable params: {n_params:,} ({n_params/1e6:.2f}M)", flush=True)

            diffusion.conditioner.eval()
            diffusion.pretransform.eval()
            dit.train()

            optimizer = torch.optim.AdamW(trainable, lr=learning_rate)
            scaler = torch.cuda.amp.GradScaler() if use_scaler else None
            pbar = comfy.utils.ProgressBar(train_steps)

            for step in range(1, train_steps + 1):
                npz_path, audio_path = random.choice(pairs)

                with torch.no_grad():
                    # Encode audio to latent space
                    audio = _load_audio(audio_path, device)
                    x0 = diffusion.pretransform.encode(audio.float()).to(dtype)  # [1, 64, L]
                    # Build conditioning from features
                    metadata = (_load_metadata(npz_path, device, dtype),)
                    conditioning = diffusion.conditioner(metadata, device)
                    cond_inputs = diffusion.get_conditioning_inputs(conditioning)

                # Rectified flow: x_t = (1 - t) * x0 + t * noise; target velocity = noise - x0
                t = torch.rand(x0.shape[0], device=device, dtype=dtype)  # [1]
                noise = torch.randn_like(x0)
                # t expanded for broadcast: [1] -> [1, 1, 1]
                t_bcast = t[:, None, None]
                x_t = (1.0 - t_bcast) * x0 + t_bcast * noise
                v_target = noise - x0

                with torch.amp.autocast(device_type=device.type, dtype=dtype):
                    v_pred = dit(x_t, t,
                                 cfg_scale=1.0,
                                 cfg_dropout_prob=cfg_dropout_prob,
                                 **cond_inputs)
                loss = F.mse_loss(v_pred.float(), v_target.float())

                if use_scaler:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()

                if step % 50 == 0:
                    print(f"[PrismAudio] step {step}/{train_steps} loss={loss.item():.6f}", flush=True)

                if step % save_every == 0:
                    ckpt_path = f"{out_root}_step{step}.safetensors"
                    save_file(self._lora_weights_for_saving(dit, dtype), ckpt_path)
                    print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True)

                pbar.update(1)

            # Save final weights
            save_file(self._lora_weights_for_saving(dit, dtype), output_path)

            # Save config alongside weights so the loader knows the structure
            config_path = f"{out_root}_config.json"
            with open(config_path, "w") as f:
                json.dump({
                    "rank": lora_rank,
                    "alpha": lora_alpha,
                    "target_modules": sorted(target_attrs),
                }, f, indent=2)

            print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
        finally:
            # Always restore the model to its base state, even on exception.
            # Otherwise LoRA wrappers would persist in the cached model and a
            # later training run would apply LoRA on top of the existing LoRA.
            dit.eval()
            _unapply_lora(dit)
            if strategy == "offload_to_cpu":
                offload = get_offload_device()
                diffusion.model.to(offload)
                diffusion.conditioner.to(offload)
                diffusion.pretransform.to(offload)
            soft_empty_cache()

        return (output_path,)