feat: LoRA trainer and loader nodes for PrismAudio DiT fine-tuning
Adds PrismAudioLoRATrainer and PrismAudioLoRALoader nodes enabling low-rank adaptation of the DiT on paired (video features + audio) datasets. - LoRALinear wraps nn.Linear with trainable lora_A/lora_B matrices - Rectified flow training loop with fp16 GradScaler, AdamW, cfg dropout - Checkpoint saving every N steps + _config.json metadata alongside weights - _unapply_lora restores base model state after training completes - Weight-merge loader: delta_W added in-place, no deep copy overhead - Three target presets: attn_only, attn_ffn (default), full Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,8 @@ _NODES = {
|
||||
"PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
|
||||
"PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
|
||||
"PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
|
||||
"PrismAudioLoRATrainer": (".lora_trainer", "PrismAudioLoRATrainer", "PrismAudio LoRA Trainer"),
|
||||
"PrismAudioLoRALoader": (".lora_loader", "PrismAudioLoRALoader", "PrismAudio LoRA Loader"),
|
||||
}
|
||||
|
||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY
|
||||
|
||||
|
||||
def _merge_lora_weights(dit: nn.Module, lora_state: dict, rank: int, alpha: float, strength: float):
|
||||
"""Add LoRA delta weights directly into the base model's nn.Linear tensors.
|
||||
|
||||
delta_W = lora_B @ lora_A * scale * strength
|
||||
applied as: linear.weight += delta_W
|
||||
|
||||
This is equivalent to LoRALinear at inference but requires no wrapper,
|
||||
no extra memory, and no change to the model's forward call graph.
|
||||
"""
|
||||
scale = (alpha / rank) * strength
|
||||
|
||||
# Group saved keys by module path
|
||||
a_map = {
|
||||
k.replace(".lora_A.weight", ""): v
|
||||
for k, v in lora_state.items() if k.endswith("lora_A.weight")
|
||||
}
|
||||
b_map = {
|
||||
k.replace(".lora_B.weight", ""): v
|
||||
for k, v in lora_state.items() if k.endswith("lora_B.weight")
|
||||
}
|
||||
|
||||
merged = 0
|
||||
for path, lora_A in a_map.items():
|
||||
if path not in b_map:
|
||||
print(f"[PrismAudio] LoRA merge: missing lora_B for {path}, skipping", flush=True)
|
||||
continue
|
||||
lora_B = b_map[path] # [out_features, rank]
|
||||
# delta_W: [out_features, in_features]
|
||||
delta_W = (lora_B.float() @ lora_A.float()) * scale
|
||||
|
||||
# Navigate to the parent module using PyTorch's get_submodule
|
||||
*parent_parts, child_name = path.split(".")
|
||||
try:
|
||||
parent = dit.get_submodule(".".join(parent_parts)) if parent_parts else dit
|
||||
except AttributeError as e:
|
||||
print(f"[PrismAudio] LoRA merge: could not find module '{path}': {e}", flush=True)
|
||||
continue
|
||||
|
||||
linear = getattr(parent, child_name, None)
|
||||
if not isinstance(linear, nn.Linear):
|
||||
print(f"[PrismAudio] LoRA merge: expected nn.Linear at '{path}', got {type(linear)}", flush=True)
|
||||
continue
|
||||
|
||||
linear.weight.data.add_(delta_W.to(linear.weight.dtype))
|
||||
merged += 1
|
||||
|
||||
print(f"[PrismAudio] LoRA merged {merged} layer(s) (strength={strength:.3f})", flush=True)
|
||||
|
||||
|
||||
class PrismAudioLoRALoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("PRISMAUDIO_MODEL",),
|
||||
"lora_path": ("STRING", {"default": "", "tooltip": "Path to .safetensors LoRA file produced by PrismAudio LoRA Trainer"}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05, "tooltip": "LoRA influence scale. 1.0 = full strength, 0.0 = base model only"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("PRISMAUDIO_MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "load_lora"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
|
||||
def load_lora(self, model, lora_path, strength):
|
||||
from safetensors.torch import load_file
|
||||
|
||||
if not os.path.exists(lora_path):
|
||||
raise FileNotFoundError(f"[PrismAudio] LoRA file not found: {lora_path}")
|
||||
|
||||
config_path = lora_path.replace(".safetensors", "_config.json")
|
||||
if not os.path.exists(config_path):
|
||||
raise FileNotFoundError(
|
||||
f"[PrismAudio] LoRA config not found: {config_path}\n"
|
||||
"Expected a _config.json alongside the .safetensors file."
|
||||
)
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
rank = config["rank"]
|
||||
alpha = config["alpha"]
|
||||
|
||||
lora_state = load_file(lora_path)
|
||||
|
||||
# Merge LoRA weights in-place into the DiT's base linear layers.
|
||||
# ComfyUI re-executes the upstream ModelLoader on the next queue run
|
||||
# when inputs change, providing a fresh base model as needed.
|
||||
dit = model["model"].model # DiffusionTransformer
|
||||
|
||||
if strength == 0.0:
|
||||
print("[PrismAudio] LoRA strength=0.0 — skipping merge, base model unchanged.", flush=True)
|
||||
return (model,)
|
||||
|
||||
_merge_lora_weights(dit, lora_state, rank, alpha, strength)
|
||||
|
||||
return (model,)
|
||||
@@ -0,0 +1,280 @@
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, SAMPLE_RATE,
|
||||
get_device, get_offload_device, soft_empty_cache,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoRA primitives
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
"""Low-rank adapter wrapping a frozen nn.Linear."""
|
||||
|
||||
def __init__(self, linear: nn.Linear, rank: int, alpha: float):
|
||||
super().__init__()
|
||||
self.linear = linear
|
||||
self.scale = alpha / rank
|
||||
in_f, out_f = linear.in_features, linear.out_features
|
||||
self.lora_A = nn.Linear(in_f, rank, bias=False)
|
||||
self.lora_B = nn.Linear(rank, out_f, bias=False)
|
||||
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B.weight)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x) + self.lora_B(self.lora_A(x)) * self.scale
|
||||
|
||||
|
||||
_TARGET_MODULE_PRESETS = {
|
||||
"attn_only": {"to_q", "to_kv", "to_qkv", "to_out"},
|
||||
"attn_ffn": {"to_q", "to_kv", "to_qkv", "to_out", "proj"},
|
||||
"full": {"to_q", "to_kv", "to_qkv", "to_out", "proj", "project_in", "project_out"},
|
||||
}
|
||||
|
||||
|
||||
def _apply_lora(module: nn.Module, target_attrs: set, rank: int, alpha: float):
|
||||
"""Recursively replace matching nn.Linear layers with LoRALinear."""
|
||||
for name, child in list(module.named_children()):
|
||||
if isinstance(child, nn.Linear) and name in target_attrs:
|
||||
setattr(module, name, LoRALinear(child, rank, alpha))
|
||||
else:
|
||||
_apply_lora(child, target_attrs, rank, alpha)
|
||||
|
||||
|
||||
def _unapply_lora(module: nn.Module):
|
||||
"""Replace LoRALinear back with the original frozen Linear (no weight merge)."""
|
||||
for name, child in list(module.named_children()):
|
||||
if isinstance(child, LoRALinear):
|
||||
child.linear.weight.requires_grad_(False)
|
||||
setattr(module, name, child.linear)
|
||||
else:
|
||||
_unapply_lora(child)
|
||||
|
||||
|
||||
def _get_lora_state_dict(module: nn.Module) -> dict:
|
||||
"""Return only LoRA parameter tensors from a module's state dict."""
|
||||
return {k: v for k, v in module.state_dict().items()
|
||||
if "lora_A" in k or "lora_B" in k}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_AUDIO_EXTS = (".wav", ".flac", ".mp3")
|
||||
|
||||
|
||||
def _scan_dataset(dataset_dir: str):
|
||||
"""Return list of (npz_path, audio_path) pairs matched by stem."""
|
||||
pairs = []
|
||||
for fname in os.listdir(dataset_dir):
|
||||
if not fname.endswith(".npz"):
|
||||
continue
|
||||
stem = os.path.join(dataset_dir, fname[:-4])
|
||||
for ext in _AUDIO_EXTS:
|
||||
audio_path = stem + ext
|
||||
if os.path.exists(audio_path):
|
||||
pairs.append((stem + ".npz", audio_path))
|
||||
break
|
||||
return sorted(pairs)
|
||||
|
||||
|
||||
def _load_audio(audio_path: str, device: torch.device) -> torch.Tensor:
|
||||
"""Load audio to [1, 2, samples] float32 tensor at SAMPLE_RATE."""
|
||||
import torchaudio
|
||||
waveform, sr = torchaudio.load(audio_path)
|
||||
if sr != SAMPLE_RATE:
|
||||
waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
|
||||
if waveform.shape[0] == 1:
|
||||
waveform = waveform.expand(2, -1)
|
||||
elif waveform.shape[0] > 2:
|
||||
waveform = waveform[:2]
|
||||
return waveform.unsqueeze(0).to(device) # [1, 2, samples]
|
||||
|
||||
|
||||
def _load_metadata(npz_path: str, device: torch.device, dtype: torch.dtype) -> dict:
|
||||
"""Load .npz features into a conditioner metadata dict."""
|
||||
import numpy as np
|
||||
data = np.load(npz_path, allow_pickle=True)
|
||||
video_feat = torch.from_numpy(data["video_features"]).float().to(device, dtype=dtype)
|
||||
text_feat = torch.from_numpy(data["text_features"]).float().to(device, dtype=dtype)
|
||||
sync_feat = torch.from_numpy(data["sync_features"]).float().to(device, dtype=dtype)
|
||||
has_video = bool(video_feat.abs().sum() > 0)
|
||||
return {
|
||||
"video_features": video_feat,
|
||||
"text_features": text_feat,
|
||||
"sync_features": sync_feat,
|
||||
"video_exist": torch.tensor(has_video),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trainer node
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PrismAudioLoRATrainer:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("PRISMAUDIO_MODEL",),
|
||||
"dataset_dir": ("STRING", {"default": "", "tooltip": "Directory containing paired .npz feature files and .wav/.flac audio files (matched by filename stem)"}),
|
||||
"output_path": ("STRING", {"default": "", "tooltip": "Save path for .safetensors weights. Empty = models/prismaudio/lora/"}),
|
||||
"lora_rank": ("INT", {"default": 64, "min": 1, "max": 512}),
|
||||
"lora_alpha": ("FLOAT", {"default": 64.0, "min": 1.0, "max": 1024.0}),
|
||||
"target_modules": (["attn_ffn", "attn_only", "full"], {"tooltip": "attn_only: Q/K/V/out only. attn_ffn: + FFN input (recommended). full: + transformer I/O projections"}),
|
||||
"learning_rate": ("FLOAT", {"default": 1e-4, "min": 1e-7, "max": 1e-2, "step": 1e-6}),
|
||||
"train_steps": ("INT", {"default": 1000, "min": 1, "max": 100000}),
|
||||
"cfg_dropout_prob": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 0.5, "step": 0.01, "tooltip": "Probability of dropping conditioning per step — preserves CFG ability at inference"}),
|
||||
"save_every": ("INT", {"default": 500, "min": 1, "max": 100000, "tooltip": "Save a checkpoint every N steps (in addition to final save)"}),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("lora_path",)
|
||||
FUNCTION = "train"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
|
||||
def train(self, model, dataset_dir, output_path, lora_rank, lora_alpha,
|
||||
target_modules, learning_rate, train_steps, cfg_dropout_prob, save_every, seed):
|
||||
from safetensors.torch import save_file
|
||||
|
||||
device = get_device()
|
||||
dtype = model["dtype"]
|
||||
diffusion = model["model"]
|
||||
strategy = model["strategy"]
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
# Scan dataset
|
||||
pairs = _scan_dataset(dataset_dir)
|
||||
if not pairs:
|
||||
raise RuntimeError(f"[PrismAudio] No (.npz + audio) pairs found in: {dataset_dir}")
|
||||
print(f"[PrismAudio] LoRA training — {len(pairs)} sample(s), {train_steps} steps", flush=True)
|
||||
|
||||
# Resolve output path
|
||||
if not output_path:
|
||||
import folder_paths
|
||||
out_dir = os.path.join(folder_paths.models_dir, "prismaudio", "lora")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
output_path = os.path.join(out_dir, f"prismaudio_lora_r{lora_rank}.safetensors")
|
||||
|
||||
# Move model to device
|
||||
diffusion.model.to(device)
|
||||
diffusion.conditioner.to(device)
|
||||
diffusion.pretransform.to(device)
|
||||
|
||||
# Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B)
|
||||
dit = diffusion.model # DiffusionTransformer
|
||||
for p in dit.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
target_attrs = _TARGET_MODULE_PRESETS[target_modules]
|
||||
_apply_lora(dit, target_attrs, lora_rank, lora_alpha)
|
||||
|
||||
# Cast LoRA params to model dtype and move to device
|
||||
for m in dit.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_A.to(device=device, dtype=dtype)
|
||||
m.lora_B.to(device=device, dtype=dtype)
|
||||
|
||||
trainable = [p for p in dit.parameters() if p.requires_grad]
|
||||
n_params = sum(p.numel() for p in trainable)
|
||||
print(f"[PrismAudio] LoRA trainable params: {n_params:,} ({n_params/1e6:.2f}M)", flush=True)
|
||||
|
||||
diffusion.conditioner.eval()
|
||||
diffusion.pretransform.eval()
|
||||
dit.train()
|
||||
|
||||
optimizer = torch.optim.AdamW(trainable, lr=learning_rate)
|
||||
|
||||
# GradScaler for fp16 to prevent underflow
|
||||
use_scaler = (dtype == torch.float16)
|
||||
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
|
||||
|
||||
pbar = comfy.utils.ProgressBar(train_steps)
|
||||
|
||||
for step in range(1, train_steps + 1):
|
||||
npz_path, audio_path = random.choice(pairs)
|
||||
|
||||
with torch.no_grad():
|
||||
# Encode audio to latent space
|
||||
audio = _load_audio(audio_path, device)
|
||||
x0 = diffusion.pretransform.encode(audio.float()).to(dtype) # [1, 64, L]
|
||||
|
||||
# Build conditioning from features
|
||||
metadata = (_load_metadata(npz_path, device, dtype),)
|
||||
conditioning = diffusion.conditioner(metadata, device)
|
||||
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
||||
|
||||
# Rectified flow: interpolate between data and noise
|
||||
t = torch.rand(x0.shape[0], device=device, dtype=dtype) # [1]
|
||||
noise = torch.randn_like(x0)
|
||||
# t expanded for broadcast: [1] -> [1, 1, 1]
|
||||
t_bcast = t[:, None, None]
|
||||
x_t = (1.0 - t_bcast) * x0 + t_bcast * noise
|
||||
v_target = noise - x0
|
||||
|
||||
with torch.amp.autocast(device_type=device.type, dtype=dtype):
|
||||
v_pred = dit(x_t, t,
|
||||
cfg_scale=1.0,
|
||||
cfg_dropout_prob=cfg_dropout_prob,
|
||||
**cond_inputs)
|
||||
|
||||
loss = F.mse_loss(v_pred.float(), v_target.float())
|
||||
|
||||
if use_scaler:
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % 50 == 0:
|
||||
print(f"[PrismAudio] step {step}/{train_steps} loss={loss.item():.6f}", flush=True)
|
||||
|
||||
if step % save_every == 0:
|
||||
ckpt_path = output_path.replace(".safetensors", f"_step{step}.safetensors")
|
||||
save_file(_get_lora_state_dict(dit), ckpt_path)
|
||||
print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
# Save final weights
|
||||
save_file(_get_lora_state_dict(dit), output_path)
|
||||
|
||||
# Save config alongside weights so the loader knows the structure
|
||||
config_path = output_path.replace(".safetensors", "_config.json")
|
||||
with open(config_path, "w") as f:
|
||||
json.dump({
|
||||
"rank": lora_rank,
|
||||
"alpha": lora_alpha,
|
||||
"target_modules": sorted(target_attrs),
|
||||
}, f, indent=2)
|
||||
|
||||
print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
|
||||
|
||||
# Restore model to base state (remove LoRA wrappers, restore original linears)
|
||||
dit.eval()
|
||||
_unapply_lora(dit)
|
||||
|
||||
if strategy == "offload_to_cpu":
|
||||
diffusion.model.to(get_offload_device())
|
||||
diffusion.conditioner.to(get_offload_device())
|
||||
diffusion.pretransform.to(get_offload_device())
|
||||
soft_empty_cache()
|
||||
|
||||
return (output_path,)
|
||||
Reference in New Issue
Block a user