feat: LoRA trainer and loader nodes for PrismAudio DiT fine-tuning
Adds PrismAudioLoRATrainer and PrismAudioLoRALoader nodes enabling low-rank adaptation of the DiT on paired (video features + audio) datasets. - LoRALinear wraps nn.Linear with trainable lora_A/lora_B matrices - Rectified flow training loop with fp16 GradScaler, AdamW, cfg dropout - Checkpoint saving every N steps + _config.json metadata alongside weights - _unapply_lora restores base model state after training completes - Weight-merge loader: delta_W added in-place, no deep copy overhead - Three target presets: attn_only, attn_ffn (default), full Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,8 @@ _NODES = {
|
|||||||
"PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
|
"PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
|
||||||
"PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
|
"PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
|
||||||
"PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
|
"PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
|
||||||
|
"PrismAudioLoRATrainer": (".lora_trainer", "PrismAudioLoRATrainer", "PrismAudio LoRA Trainer"),
|
||||||
|
"PrismAudioLoRALoader": (".lora_loader", "PrismAudioLoRALoader", "PrismAudio LoRA Loader"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .utils import PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_lora_weights(dit: nn.Module, lora_state: dict, rank: int, alpha: float, strength: float) -> int:
    """Fold LoRA deltas directly into the matching ``nn.Linear`` weights of ``dit``.

    For every module path that has both a ``lora_A`` and a ``lora_B`` tensor::

        delta_W = (lora_B @ lora_A) * (alpha / rank) * strength
        linear.weight += delta_W

    The update is applied in place, so no wrapper module is inserted and the
    model's forward call graph is untouched. Entries that cannot be applied
    (missing ``lora_B``, unknown module path, target that is not an
    ``nn.Linear``, incompatible shapes) are reported and skipped.

    Args:
        dit: Module whose linear layers receive the deltas.
        lora_state: Mapping with ``<path>.lora_A.weight`` and
            ``<path>.lora_B.weight`` keys.
        rank: LoRA rank, used for the ``alpha / rank`` scaling.
        alpha: LoRA alpha, used for the ``alpha / rank`` scaling.
        strength: Additional multiplier applied on top of ``alpha / rank``.

    Returns:
        The number of linear layers that were updated.
    """
    scale = (alpha / rank) * strength

    # Group saved keys by module path. Only the trailing suffix is stripped so
    # a path that happens to contain the same text elsewhere is left intact.
    a_suffix = ".lora_A.weight"
    b_suffix = ".lora_B.weight"
    a_map = {k[: -len(a_suffix)]: v for k, v in lora_state.items() if k.endswith(a_suffix)}
    b_map = {k[: -len(b_suffix)]: v for k, v in lora_state.items() if k.endswith(b_suffix)}

    merged = 0
    for path, lora_A in a_map.items():
        lora_B = b_map.get(path)
        if lora_B is None:
            print(f"[PrismAudio] LoRA merge: missing lora_B for {path}, skipping", flush=True)
            continue

        # Navigate to the parent module using PyTorch's get_submodule
        *parent_parts, child_name = path.split(".")
        try:
            parent = dit.get_submodule(".".join(parent_parts)) if parent_parts else dit
        except AttributeError as e:
            print(f"[PrismAudio] LoRA merge: could not find module '{path}': {e}", flush=True)
            continue

        linear = getattr(parent, child_name, None)
        if not isinstance(linear, nn.Linear):
            print(f"[PrismAudio] LoRA merge: expected nn.Linear at '{path}', got {type(linear)}", flush=True)
            continue

        weight = linear.weight
        # Validate shapes up front: add_() would otherwise silently broadcast a
        # wrongly-shaped delta (e.g. [1, in]) across the whole weight matrix.
        shapes_ok = (
            lora_A.dim() == 2
            and lora_B.dim() == 2
            and lora_B.shape[1] == lora_A.shape[0]
            and (lora_B.shape[0], lora_A.shape[1]) == tuple(weight.shape)
        )
        if not shapes_ok:
            print(
                f"[PrismAudio] LoRA merge: shape mismatch at '{path}' "
                f"(lora_B {tuple(lora_B.shape)}, lora_A {tuple(lora_A.shape)}, "
                f"weight {tuple(weight.shape)}), skipping",
                flush=True,
            )
            continue

        # Compute the delta in float32 on the weight's own device; the LoRA
        # tensors may live on a different device than the model.
        delta_W = (lora_B.to(weight.device).float() @ lora_A.to(weight.device).float()) * scale
        with torch.no_grad():
            weight.add_(delta_W.to(weight.dtype))
        merged += 1

    print(f"[PrismAudio] LoRA merged {merged} layer(s) (strength={strength:.3f})", flush=True)
    return merged
|
||||||
|
|
||||||
|
|
||||||
|
class PrismAudioLoRALoader:
    """ComfyUI node that merges a trained PrismAudio LoRA into a loaded model.

    The LoRA deltas are added in place to the model's linear weights, and the
    same (mutated) model object is returned.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: model, LoRA file path and merge strength."""
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "lora_path": ("STRING", {"default": "", "tooltip": "Path to .safetensors LoRA file produced by PrismAudio LoRA Trainer"}),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05, "tooltip": "LoRA influence scale. 1.0 = full strength, 0.0 = base model only"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_lora"
    CATEGORY = PRISMAUDIO_CATEGORY

    @staticmethod
    def _candidate_config_paths(lora_path):
        """Return the config files to try for ``lora_path``, most specific first.

        The primary location is ``<stem>_config.json`` next to the weights.
        For a ``<stem>_step<N>.safetensors`` checkpoint the run-level
        ``<stem>_config.json`` is tried as a fallback.
        """
        stem, ext = os.path.splitext(lora_path)
        if ext.lower() != ".safetensors":
            # Unknown extension: keep the full name so the config path can
            # never resolve to the weights file itself.
            stem = lora_path
        candidates = [stem + "_config.json"]
        head, sep, step = stem.rpartition("_step")
        if sep and step.isdigit():
            candidates.append(head + "_config.json")
        return candidates

    def load_lora(self, model, lora_path, strength):
        """Merge the LoRA stored at ``lora_path`` into ``model`` and return it.

        Args:
            model: Mapping whose ``"model"`` entry exposes the DiT as ``.model``.
            lora_path: Path to the LoRA ``.safetensors`` file.
            strength: Multiplier for the LoRA delta; ``0.0`` skips the merge.

        Returns:
            A 1-tuple containing the same ``model`` object.

        Raises:
            FileNotFoundError: If the LoRA file or its config JSON is missing.
            KeyError: If the config lacks ``rank`` or ``alpha``.
            ValueError: If the configured rank is not a positive number.
        """
        from safetensors.torch import load_file

        if not os.path.exists(lora_path):
            raise FileNotFoundError(f"[PrismAudio] LoRA file not found: {lora_path}")

        candidates = self._candidate_config_paths(lora_path)
        config_path = next((p for p in candidates if os.path.exists(p)), None)
        if config_path is None:
            raise FileNotFoundError(
                f"[PrismAudio] LoRA config not found: {candidates[0]}\n"
                "Expected a _config.json alongside the .safetensors file."
            )

        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        try:
            rank = config["rank"]
            alpha = config["alpha"]
        except KeyError as e:
            raise KeyError(f"[PrismAudio] LoRA config {config_path} is missing key {e}") from e

        dit = model["model"].model  # DiffusionTransformer

        if strength == 0.0:
            print("[PrismAudio] LoRA strength=0.0 — skipping merge, base model unchanged.", flush=True)
            return (model,)

        if not isinstance(rank, (int, float)) or rank <= 0:
            raise ValueError(f"[PrismAudio] LoRA config {config_path} has invalid rank: {rank!r}")

        # Only read the weights once we know a merge will actually happen.
        lora_state = load_file(lora_path)

        # Merge LoRA weights in-place into the DiT's base linear layers.
        # NOTE(review): the merge mutates the incoming model object. If the
        # host caches the upstream loader's output and re-runs only this node
        # (e.g. after changing `strength`), the new delta would be added on top
        # of the previous one — confirm how the upstream model is refreshed.
        _merge_lora_weights(dit, lora_state, rank, alpha, strength)

        return (model,)
|
||||||
@@ -0,0 +1,280 @@
|
|||||||
|
import os
|
||||||
|
import math
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
PRISMAUDIO_CATEGORY, SAMPLE_RATE,
|
||||||
|
get_device, get_offload_device, soft_empty_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LoRA primitives
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class LoRALinear(nn.Module):
    """An existing ``nn.Linear`` plus a low-rank residual branch.

    The output is ``linear(x) + lora_B(lora_A(x)) * (alpha / rank)``. The
    wrapped layer is stored as ``linear``; the adapter consists of two
    bias-free linears, ``lora_A`` (in -> rank) and ``lora_B`` (rank -> out).
    """

    def __init__(self, linear: nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.linear = linear
        self.scale = alpha / rank

        # Down-projection followed by up-projection.
        self.lora_A = nn.Linear(linear.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, linear.out_features, bias=False)

        # lora_B starts at zero, so the adapter contributes nothing until it
        # has been trained; lora_A gets a Kaiming-uniform start.
        nn.init.zeros_(self.lora_B.weight)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank term."""
        base_out = self.linear(x)
        low_rank_out = self.lora_B(self.lora_A(x))
        return base_out + low_rank_out * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
# Preset name -> attribute names of the nn.Linear layers that get a LoRA
# adapter. Each preset is a superset of the one before it.
_TARGET_MODULE_PRESETS = {
    preset: set(attr_names)
    for preset, attr_names in (
        ("attn_only", ("to_q", "to_kv", "to_qkv", "to_out")),
        ("attn_ffn", ("to_q", "to_kv", "to_qkv", "to_out", "proj")),
        ("full", ("to_q", "to_kv", "to_qkv", "to_out", "proj", "project_in", "project_out")),
    )
}
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_lora(module: nn.Module, target_attrs: set, rank: int, alpha: float):
    """Wrap every targeted ``nn.Linear`` below ``module`` in a ``LoRALinear``.

    A child is wrapped when it is an ``nn.Linear`` and its attribute name is in
    ``target_attrs``; every other child is searched recursively. The module
    tree is modified in place.
    """
    # Snapshot the children first: setattr below changes the registry that
    # named_children() iterates over.
    children = list(module.named_children())
    for attr_name, submodule in children:
        is_target = isinstance(submodule, nn.Linear) and attr_name in target_attrs
        if not is_target:
            _apply_lora(submodule, target_attrs, rank, alpha)
            continue
        setattr(module, attr_name, LoRALinear(submodule, rank, alpha))
|
||||||
|
|
||||||
|
|
||||||
|
def _unapply_lora(module: nn.Module):
    """Undo ``_apply_lora``: put each wrapped linear back in place of its wrapper.

    The adapter weights are discarded, not merged; the restored linear's
    weight is marked as not requiring gradients.
    """
    for attr_name, submodule in list(module.named_children()):
        if not isinstance(submodule, LoRALinear):
            _unapply_lora(submodule)
            continue
        original = submodule.linear
        original.weight.requires_grad_(False)
        setattr(module, attr_name, original)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_lora_state_dict(module: nn.Module) -> dict:
    """Collect the state-dict entries whose key mentions ``lora_A`` or ``lora_B``."""
    markers = ("lora_A", "lora_B")
    lora_tensors = {}
    for key, tensor in module.state_dict().items():
        if any(marker in key for marker in markers):
            lora_tensors[key] = tensor
    return lora_tensors
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dataset helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Audio extensions tried for each feature file, in order of preference.
_AUDIO_EXTS = (".wav", ".flac", ".mp3")


def _scan_dataset(dataset_dir: str):
    """Pair every ``.npz`` file in ``dataset_dir`` with a same-stem audio file.

    For each ``.npz`` the first existing ``<stem><ext>`` from ``_AUDIO_EXTS``
    is used; feature files without audio are left out. Returns a sorted list
    of ``(npz_path, audio_path)`` tuples.
    """
    matched = []
    npz_names = (entry for entry in os.listdir(dataset_dir) if entry.endswith(".npz"))
    for npz_name in npz_names:
        base = os.path.join(dataset_dir, npz_name[:-4])
        audio_path = next(
            (base + ext for ext in _AUDIO_EXTS if os.path.exists(base + ext)),
            None,
        )
        if audio_path is not None:
            matched.append((base + ".npz", audio_path))
    matched.sort()
    return matched
|
||||||
|
|
||||||
|
|
||||||
|
def _load_audio(audio_path: str, device: torch.device) -> torch.Tensor:
    """Read an audio file as a two-channel batch of shape [1, 2, samples] on ``device``.

    The waveform is resampled to ``SAMPLE_RATE`` when the file uses another
    rate. Mono input is expanded to two channels; input with more than two
    channels keeps only the first two.
    """
    import torchaudio

    wav, source_rate = torchaudio.load(audio_path)
    if source_rate != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, source_rate, SAMPLE_RATE)

    n_channels = wav.shape[0]
    if n_channels == 1:
        wav = wav.expand(2, -1)
    elif n_channels > 2:
        wav = wav[:2]

    batched = wav.unsqueeze(0)  # [1, 2, samples]
    return batched.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_metadata(npz_path: str, device: torch.device, dtype: torch.dtype) -> dict:
    """Load the feature arrays of one ``.npz`` sample into a metadata dict.

    Args:
        npz_path: Archive containing ``video_features``, ``text_features`` and
            ``sync_features`` arrays.
        device: Device the tensors are moved to.
        dtype: Dtype the tensors are cast to.

    Returns:
        Dict with the three feature tensors plus ``video_exist``, a boolean
        tensor that is true when ``video_features`` has any non-zero entry.

    Raises:
        KeyError: If one of the expected arrays is missing from the archive.
    """
    import numpy as np

    feature_keys = ("video_features", "text_features", "sync_features")

    # NOTE(review): allow_pickle=True lets a crafted .npz run arbitrary code on
    # load. Kept as-is in case the archives hold object arrays — confirm
    # against the feature extractor and switch it off if they do not.
    # The context manager closes the archive's file handle once the arrays
    # have been read.
    with np.load(npz_path, allow_pickle=True) as data:
        missing = [key for key in feature_keys if key not in data.files]
        if missing:
            raise KeyError(f"[PrismAudio] {npz_path} is missing feature array(s): {', '.join(missing)}")
        metadata = {
            key: torch.from_numpy(data[key]).float().to(device, dtype=dtype)
            for key in feature_keys
        }

    has_video = bool(metadata["video_features"].abs().sum() > 0)
    metadata["video_exist"] = torch.tensor(has_video)
    return metadata
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Trainer node
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class PrismAudioLoRATrainer:
    """ComfyUI node that trains LoRA adapters on the PrismAudio DiT.

    The DiT's parameters are frozen, the targeted linear layers are wrapped in
    ``LoRALinear`` for the duration of training, and only the adapter weights
    are optimized and saved. The wrappers are removed again before the node
    returns, whether training succeeds or fails.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs (model, dataset, output path and hyper-parameters)."""
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "dataset_dir": ("STRING", {"default": "", "tooltip": "Directory containing paired .npz feature files and .wav/.flac audio files (matched by filename stem)"}),
                "output_path": ("STRING", {"default": "", "tooltip": "Save path for .safetensors weights. Empty = models/prismaudio/lora/"}),
                "lora_rank": ("INT", {"default": 64, "min": 1, "max": 512}),
                "lora_alpha": ("FLOAT", {"default": 64.0, "min": 1.0, "max": 1024.0}),
                "target_modules": (["attn_ffn", "attn_only", "full"], {"tooltip": "attn_only: Q/K/V/out only. attn_ffn: + FFN input (recommended). full: + transformer I/O projections"}),
                "learning_rate": ("FLOAT", {"default": 1e-4, "min": 1e-7, "max": 1e-2, "step": 1e-6}),
                "train_steps": ("INT", {"default": 1000, "min": 1, "max": 100000}),
                "cfg_dropout_prob": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 0.5, "step": 0.01, "tooltip": "Probability of dropping conditioning per step — preserves CFG ability at inference"}),
                "save_every": ("INT", {"default": 500, "min": 1, "max": 100000, "tooltip": "Save a checkpoint every N steps (in addition to final save)"}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("lora_path",)
    FUNCTION = "train"
    CATEGORY = PRISMAUDIO_CATEGORY

    @staticmethod
    def _write_config(weights_path, lora_config):
        """Write ``lora_config`` as ``<stem>_config.json`` next to ``weights_path``."""
        stem = weights_path[: -len(".safetensors")]
        with open(stem + "_config.json", "w", encoding="utf-8") as f:
            json.dump(lora_config, f, indent=2)

    def train(self, model, dataset_dir, output_path, lora_rank, lora_alpha,
              target_modules, learning_rate, train_steps, cfg_dropout_prob, save_every, seed):
        """Run LoRA fine-tuning and return the path of the saved adapter weights.

        Each step picks a random (features, audio) pair, encodes the audio to
        a latent, builds the rectified-flow interpolation between latent and
        noise, and regresses the DiT output onto ``noise - latent`` with MSE.

        Args:
            model: Mapping with ``"model"`` (diffusion wrapper), ``"dtype"``
                and ``"strategy"`` entries.
            dataset_dir: Directory scanned for ``.npz`` + audio pairs.
            output_path: Destination ``.safetensors`` path; empty selects the
                default location under the ComfyUI models directory.
            lora_rank: Rank of the adapters.
            lora_alpha: Alpha used for the ``alpha / rank`` scaling.
            target_modules: Key into ``_TARGET_MODULE_PRESETS``.
            learning_rate: AdamW learning rate.
            train_steps: Number of optimizer steps.
            cfg_dropout_prob: Forwarded to the DiT as ``cfg_dropout_prob``.
            save_every: Checkpoint interval in steps.
            seed: Seed for ``torch`` and ``random``.

        Returns:
            A 1-tuple with the final ``.safetensors`` path.

        Raises:
            RuntimeError: If the dataset has no pairs or no layer matched the
                selected target preset.
        """
        from safetensors.torch import save_file

        device = get_device()
        dtype = model["dtype"]
        diffusion = model["model"]
        strategy = model["strategy"]

        torch.manual_seed(seed)
        random.seed(seed)

        # Scan dataset
        pairs = _scan_dataset(dataset_dir)
        if not pairs:
            raise RuntimeError(f"[PrismAudio] No (.npz + audio) pairs found in: {dataset_dir}")
        print(f"[PrismAudio] LoRA training — {len(pairs)} sample(s), {train_steps} steps", flush=True)

        # Resolve output path
        if not output_path:
            import folder_paths
            out_dir = os.path.join(folder_paths.models_dir, "prismaudio", "lora")
            os.makedirs(out_dir, exist_ok=True)
            output_path = os.path.join(out_dir, f"prismaudio_lora_r{lora_rank}.safetensors")
        else:
            # Checkpoint and config names are derived from the ".safetensors"
            # suffix; without it the config would be written over the weights.
            if not output_path.endswith(".safetensors"):
                output_path += ".safetensors"
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
        path_stem = output_path[: -len(".safetensors")]

        target_attrs = _TARGET_MODULE_PRESETS[target_modules]
        # Saved alongside every weights file so the loader knows the structure.
        lora_config = {
            "rank": lora_rank,
            "alpha": lora_alpha,
            "target_modules": sorted(target_attrs),
        }

        # Move model to device
        diffusion.model.to(device)
        diffusion.conditioner.to(device)
        diffusion.pretransform.to(device)

        # Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B)
        dit = diffusion.model  # DiffusionTransformer
        for p in dit.parameters():
            p.requires_grad_(False)

        # GradScaler for fp16 to prevent underflow
        use_scaler = (dtype == torch.float16) and device.type == "cuda"
        # GradScaler refuses to unscale fp16 gradients, so when it is active
        # the trainable adapters stay in float32 and autocast handles the
        # half-precision compute.
        lora_dtype = torch.float32 if use_scaler else dtype

        try:
            _apply_lora(dit, target_attrs, lora_rank, lora_alpha)

            # Move LoRA params to the device in their training dtype
            for m in dit.modules():
                if isinstance(m, LoRALinear):
                    m.lora_A.to(device=device, dtype=lora_dtype)
                    m.lora_B.to(device=device, dtype=lora_dtype)

            trainable = [p for p in dit.parameters() if p.requires_grad]
            if not trainable:
                raise RuntimeError(
                    f"[PrismAudio] No nn.Linear layers matched target_modules='{target_modules}'; nothing to train."
                )
            n_params = sum(p.numel() for p in trainable)
            print(f"[PrismAudio] LoRA trainable params: {n_params:,} ({n_params/1e6:.2f}M)", flush=True)

            diffusion.conditioner.eval()
            diffusion.pretransform.eval()
            dit.train()

            optimizer = torch.optim.AdamW(trainable, lr=learning_rate)

            scaler = None
            if use_scaler:
                # Prefer the device-generic GradScaler where this torch
                # version provides it.
                if hasattr(torch.amp, "GradScaler"):
                    scaler = torch.amp.GradScaler("cuda")
                else:
                    scaler = torch.cuda.amp.GradScaler()

            pbar = comfy.utils.ProgressBar(train_steps)

            # NOTE(review): backward() below needs autograd to be enabled. If
            # the host runs node functions under torch.inference_mode() or
            # no_grad(), this loop has to opt out of it explicitly — TODO
            # confirm against the ComfyUI version being targeted.
            for step in range(1, train_steps + 1):
                npz_path, audio_path = random.choice(pairs)

                with torch.no_grad():
                    # Encode audio to latent space
                    audio = _load_audio(audio_path, device)
                    x0 = diffusion.pretransform.encode(audio.float()).to(dtype)  # [1, 64, L]

                    # Build conditioning from features
                    metadata = (_load_metadata(npz_path, device, dtype),)
                    conditioning = diffusion.conditioner(metadata, device)
                    cond_inputs = diffusion.get_conditioning_inputs(conditioning)

                    # Rectified flow: interpolate between data and noise
                    t = torch.rand(x0.shape[0], device=device, dtype=dtype)  # [1]
                    noise = torch.randn_like(x0)
                    # t expanded for broadcast: [1] -> [1, 1, 1]
                    t_bcast = t[:, None, None]
                    x_t = (1.0 - t_bcast) * x0 + t_bcast * noise
                    v_target = noise - x0

                with torch.amp.autocast(device_type=device.type, dtype=dtype):
                    v_pred = dit(x_t, t,
                                 cfg_scale=1.0,
                                 cfg_dropout_prob=cfg_dropout_prob,
                                 **cond_inputs)

                loss = F.mse_loss(v_pred.float(), v_target.float())

                if scaler is not None:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()

                if step % 50 == 0:
                    print(f"[PrismAudio] step {step}/{train_steps}  loss={loss.item():.6f}", flush=True)

                if step % save_every == 0:
                    ckpt_path = f"{path_stem}_step{step}.safetensors"
                    save_file(_get_lora_state_dict(dit), ckpt_path)
                    # Give each checkpoint its own config so it can be loaded
                    # even if training never reaches the final save.
                    self._write_config(ckpt_path, lora_config)
                    print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True)

                pbar.update(1)

            # Save final weights
            save_file(_get_lora_state_dict(dit), output_path)

            # Save config alongside weights so the loader knows the structure
            self._write_config(output_path, lora_config)

            print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
        finally:
            # Restore model to base state (remove LoRA wrappers, restore
            # original linears) even when training failed, so the shared
            # model object is never left wrapped.
            dit.eval()
            _unapply_lora(dit)

            if strategy == "offload_to_cpu":
                diffusion.model.to(get_offload_device())
                diffusion.conditioner.to(get_offload_device())
                diffusion.pretransform.to(get_offload_device())
                soft_empty_cache()

        return (output_path,)
|
||||||
@@ -9,3 +9,4 @@ descript-audio-codec
|
|||||||
vector-quantize-pytorch
|
vector-quantize-pytorch
|
||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
|
torchaudio
|
||||||
|
|||||||
Reference in New Issue
Block a user