diff --git a/nodes/__init__.py b/nodes/__init__.py
index 51182bc..7053341 100644
--- a/nodes/__init__.py
+++ b/nodes/__init__.py
@@ -7,6 +7,8 @@ _NODES = {
     "PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
     "PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
     "PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
+    "PrismAudioLoRATrainer": (".lora_trainer", "PrismAudioLoRATrainer", "PrismAudio LoRA Trainer"),
+    "PrismAudioLoRALoader": (".lora_loader", "PrismAudioLoRALoader", "PrismAudio LoRA Loader"),
 }
 
 for key, (module_path, class_name, display_name) in _NODES.items():
diff --git a/nodes/lora_loader.py b/nodes/lora_loader.py
new file mode 100644
index 0000000..f88d53e
--- /dev/null
+++ b/nodes/lora_loader.py
@@ -0,0 +1,164 @@
+import os
+import json
+import torch
+import torch.nn as nn
+
+from .utils import PRISMAUDIO_CATEGORY
+
+
+_LORA_A_SUFFIX = ".lora_A.weight"
+_LORA_B_SUFFIX = ".lora_B.weight"
+
+# Attribute set on the DiT module: maps module path -> CPU copy of the
+# pre-merge weight for every layer this node has modified.
+_BACKUP_ATTR = "_prismaudio_lora_backup"
+
+
+def _restore_base_weights(dit: nn.Module) -> int:
+    """Undo a previous in-place merge by copying the saved base weights back.
+
+    Returns the number of layers restored (0 when no merge is recorded).
+    """
+    backup = getattr(dit, _BACKUP_ATTR, None)
+    if not backup:
+        return 0
+    for path, base_weight in backup.items():
+        # copy_ converts device and dtype to match the live parameter.
+        dit.get_submodule(path).weight.data.copy_(base_weight)
+    restored = len(backup)
+    backup.clear()
+    return restored
+
+
+def _merge_lora_weights(dit: nn.Module, lora_state: dict, rank: int, alpha: float, strength: float):
+    """Add LoRA delta weights directly into the base model's nn.Linear tensors.
+
+    delta_W = lora_B @ lora_A * (alpha / rank) * strength
+    applied as: linear.weight += delta_W
+
+    This is equivalent to LoRALinear at inference but requires no wrapper and
+    no change to the model's forward call graph. The pre-merge weight of each
+    touched layer is copied to the CPU first so that _restore_base_weights
+    can undo the merge later.
+
+    Entries that cannot be applied (missing lora_B, unknown module path,
+    non-Linear target, shape mismatch) are logged and skipped.
+    """
+    scale = (alpha / rank) * strength
+
+    # Group saved keys by module path (key minus the LoRA suffix)
+    a_map = {
+        k[:-len(_LORA_A_SUFFIX)]: v
+        for k, v in lora_state.items() if k.endswith(_LORA_A_SUFFIX)
+    }
+    b_map = {
+        k[:-len(_LORA_B_SUFFIX)]: v
+        for k, v in lora_state.items() if k.endswith(_LORA_B_SUFFIX)
+    }
+
+    backup = getattr(dit, _BACKUP_ATTR, None)
+    if backup is None:
+        backup = {}
+        setattr(dit, _BACKUP_ATTR, backup)
+
+    merged = 0
+    for path, lora_A in a_map.items():
+        lora_B = b_map.get(path)  # [out_features, rank]
+        if lora_B is None:
+            print(f"[PrismAudio] LoRA merge: missing lora_B for {path}, skipping", flush=True)
+            continue
+
+        try:
+            linear = dit.get_submodule(path)
+        except AttributeError as e:
+            print(f"[PrismAudio] LoRA merge: could not find module '{path}': {e}", flush=True)
+            continue
+
+        if not isinstance(linear, nn.Linear):
+            print(f"[PrismAudio] LoRA merge: expected nn.Linear at '{path}', got {type(linear)}", flush=True)
+            continue
+
+        weight = linear.weight
+        # load_file returns CPU tensors; compute on the weight's device so the
+        # in-place add below cannot fail with a device mismatch.
+        # delta_W: [out_features, in_features]
+        delta_W = (lora_B.to(weight.device).float() @ lora_A.to(weight.device).float()) * scale
+        if delta_W.shape != weight.shape:
+            print(
+                f"[PrismAudio] LoRA merge: shape mismatch at '{path}' "
+                f"(delta {tuple(delta_W.shape)} vs weight {tuple(weight.shape)}), skipping",
+                flush=True,
+            )
+            continue
+
+        if path not in backup:
+            backup[path] = weight.detach().to("cpu", copy=True)
+        weight.data.add_(delta_W.to(weight.dtype))
+        merged += 1
+
+    print(f"[PrismAudio] LoRA merged {merged} layer(s) (strength={strength:.3f})", flush=True)
+
+
+class PrismAudioLoRALoader:
+    """ComfyUI node that merges a trained LoRA into a PrismAudio model in place."""
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "model": ("PRISMAUDIO_MODEL",),
+                "lora_path": ("STRING", {"default": "", "tooltip": "Path to .safetensors LoRA file produced by PrismAudio LoRA Trainer"}),
+                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05, "tooltip": "LoRA influence scale. 1.0 = full strength, 0.0 = base model only"}),
+            },
+        }
+
+    RETURN_TYPES = ("PRISMAUDIO_MODEL",)
+    RETURN_NAMES = ("model",)
+    FUNCTION = "load_lora"
+    CATEGORY = PRISMAUDIO_CATEGORY
+
+    def load_lora(self, model, lora_path, strength):
+        """Merge the LoRA at ``lora_path`` into ``model`` and return that same model.
+
+        Raises FileNotFoundError when the weights file or its ``_config.json``
+        sidecar is missing, and KeyError when the config lacks rank/alpha.
+        """
+        from safetensors.torch import load_file
+
+        if not os.path.exists(lora_path):
+            raise FileNotFoundError(f"[PrismAudio] LoRA file not found: {lora_path}")
+
+        # splitext strips only the final extension; str.replace would also rewrite
+        # ".safetensors" elsewhere in the path and leave an extension-less path
+        # unchanged, i.e. pointing at the weights file itself.
+        config_path = os.path.splitext(lora_path)[0] + "_config.json"
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(
+                f"[PrismAudio] LoRA config not found: {config_path}\n"
+                "Expected a _config.json alongside the .safetensors file."
+            )
+
+        with open(config_path, encoding="utf-8") as f:
+            config = json.load(f)
+
+        rank = config["rank"]
+        alpha = config["alpha"]
+
+        dit = model["model"].model  # DiffusionTransformer
+
+        # The merge mutates the DiT in place. NOTE(review): ComfyUI is understood
+        # to cache upstream node outputs, so re-running this node with a new
+        # strength or file would receive the already-merged model; start from
+        # the recorded base weights so deltas never stack.
+        restored = _restore_base_weights(dit)
+        if restored:
+            print(f"[PrismAudio] LoRA: restored {restored} layer(s) to base weights", flush=True)
+
+        if strength == 0.0:
+            print("[PrismAudio] LoRA strength=0.0 — skipping merge, base model unchanged.", flush=True)
+            return (model,)
+
+        lora_state = load_file(lora_path)
+        _merge_lora_weights(dit, lora_state, rank, alpha, strength)
+
+        return (model,)
diff --git a/nodes/lora_trainer.py b/nodes/lora_trainer.py
new file mode 100644
index 0000000..0756776
--- /dev/null
+++ b/nodes/lora_trainer.py
@@ -0,0 +1,336 @@
+import os
+import math
+import json
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import comfy.utils
+
+from .utils import (
+    PRISMAUDIO_CATEGORY, SAMPLE_RATE,
+    get_device, get_offload_device, soft_empty_cache,
+)
+
+
+# ---------------------------------------------------------------------------
+# LoRA primitives
+# ---------------------------------------------------------------------------
+
+class LoRALinear(nn.Module):
+    """Low-rank adapter wrapping a frozen nn.Linear.
+
+    Output is ``linear(x) + lora_B(lora_A(x)) * (alpha / rank)``. lora_B is
+    zero-initialised, so a freshly wrapped layer reproduces the base layer.
+    """
+
+    def __init__(self, linear: nn.Linear, rank: int, alpha: float):
+        super().__init__()
+        self.linear = linear
+        self.scale = alpha / rank
+        in_f, out_f = linear.in_features, linear.out_features
+        self.lora_A = nn.Linear(in_f, rank, bias=False)
+        self.lora_B = nn.Linear(rank, out_f, bias=False)
+        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
+        nn.init.zeros_(self.lora_B.weight)
+
+    def forward(self, x):
+        return self.linear(x) + self.lora_B(self.lora_A(x)) * self.scale
+
+
+_TARGET_MODULE_PRESETS = {
+    "attn_only": {"to_q", "to_kv", "to_qkv", "to_out"},
+    "attn_ffn": {"to_q", "to_kv", "to_qkv", "to_out", "proj"},
+    "full": {"to_q", "to_kv", "to_qkv", "to_out", "proj", "project_in", "project_out"},
+}
+
+
+def _apply_lora(module: nn.Module, target_attrs: set, rank: int, alpha: float):
+    """Recursively replace matching nn.Linear layers with LoRALinear."""
+    for name, child in list(module.named_children()):
+        if isinstance(child, nn.Linear) and name in target_attrs:
+            setattr(module, name, LoRALinear(child, rank, alpha))
+        else:
+            _apply_lora(child, target_attrs, rank, alpha)
+
+
+def _unapply_lora(module: nn.Module):
+    """Replace LoRALinear back with the original frozen Linear (no weight merge)."""
+    for name, child in list(module.named_children()):
+        if isinstance(child, LoRALinear):
+            child.linear.weight.requires_grad_(False)
+            setattr(module, name, child.linear)
+        else:
+            _unapply_lora(child)
+
+
+def _get_lora_state_dict(module: nn.Module) -> dict:
+    """Return only LoRA parameter tensors from a module's state dict."""
+    return {k: v for k, v in module.state_dict().items()
+            if "lora_A" in k or "lora_B" in k}
+
+
+def _save_lora(module: nn.Module, weights_path: str, config: dict, save_file):
+    """Write LoRA weights plus the ``_config.json`` sidecar the loader requires."""
+    save_file(_get_lora_state_dict(module), weights_path)
+    config_path = os.path.splitext(weights_path)[0] + "_config.json"
+    with open(config_path, "w", encoding="utf-8") as f:
+        json.dump(config, f, indent=2)
+
+
+# ---------------------------------------------------------------------------
+# Dataset helpers
+# ---------------------------------------------------------------------------
+
+_AUDIO_EXTS = (".wav", ".flac", ".mp3")
+
+
+def _scan_dataset(dataset_dir: str):
+    """Return sorted (npz_path, audio_path) pairs matched by stem.
+
+    Only the top level of ``dataset_dir`` is listed; an .npz without a
+    same-stem audio file is ignored.
+    """
+    pairs = []
+    for fname in os.listdir(dataset_dir):
+        if not fname.endswith(".npz"):
+            continue
+        stem = os.path.join(dataset_dir, fname[:-4])
+        for ext in _AUDIO_EXTS:
+            audio_path = stem + ext
+            if os.path.exists(audio_path):
+                pairs.append((stem + ".npz", audio_path))
+                break
+    return sorted(pairs)
+
+
+def _load_audio(audio_path: str, device: torch.device) -> torch.Tensor:
+    """Load audio to [1, 2, samples] float32 tensor at SAMPLE_RATE."""
+    import torchaudio
+    waveform, sr = torchaudio.load(audio_path)
+    if sr != SAMPLE_RATE:
+        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
+    if waveform.shape[0] == 1:
+        waveform = waveform.expand(2, -1)
+    elif waveform.shape[0] > 2:
+        waveform = waveform[:2]
+    return waveform.unsqueeze(0).to(device)  # [1, 2, samples]
+
+
+def _load_metadata(npz_path: str, device: torch.device, dtype: torch.dtype) -> dict:
+    """Load .npz features into a conditioner metadata dict."""
+    import numpy as np
+    # The three features are consumed by torch.from_numpy, which cannot take
+    # object arrays, so pickle support is never needed here; leaving it off
+    # avoids executing pickled payloads embedded in a dataset file.
+    with np.load(npz_path, allow_pickle=False) as data:
+        video_feat = torch.from_numpy(data["video_features"]).float().to(device, dtype=dtype)
+        text_feat = torch.from_numpy(data["text_features"]).float().to(device, dtype=dtype)
+        sync_feat = torch.from_numpy(data["sync_features"]).float().to(device, dtype=dtype)
+    has_video = bool(video_feat.abs().sum() > 0)
+    return {
+        "video_features": video_feat,
+        "text_features": text_feat,
+        "sync_features": sync_feat,
+        "video_exist": torch.tensor(has_video),
+    }
+
+
+# ---------------------------------------------------------------------------
+# Training loop
+# ---------------------------------------------------------------------------
+
+def _fit_lora(diffusion, dit, pairs, device, dtype, use_scaler, learning_rate, train_steps,
+              cfg_dropout_prob, save_every, output_path, lora_config, save_file):
+    """Optimise the LoRA parameters of ``dit`` and write checkpoints + final weights.
+
+    Expects ``dit`` to be already wrapped by ``_apply_lora`` with its base
+    parameters frozen, and ``output_path`` to end in ``.safetensors``.
+    """
+    trainable = [p for p in dit.parameters() if p.requires_grad]
+    if not trainable:
+        raise RuntimeError("[PrismAudio] No nn.Linear layers matched the selected target_modules; nothing to train")
+    n_params = sum(p.numel() for p in trainable)
+    print(f"[PrismAudio] LoRA trainable params: {n_params:,} ({n_params/1e6:.2f}M)", flush=True)
+
+    diffusion.conditioner.eval()
+    diffusion.pretransform.eval()
+    dit.train()
+
+    optimizer = torch.optim.AdamW(trainable, lr=learning_rate)
+
+    # GradScaler for fp16 to prevent underflow
+    scaler = torch.cuda.amp.GradScaler() if use_scaler else None
+
+    output_stem = output_path[:-len(".safetensors")]
+    pbar = comfy.utils.ProgressBar(train_steps)
+
+    for step in range(1, train_steps + 1):
+        npz_path, audio_path = random.choice(pairs)
+
+        with torch.no_grad():
+            # Encode audio to latent space
+            audio = _load_audio(audio_path, device)
+            x0 = diffusion.pretransform.encode(audio.float()).to(dtype)  # [1, 64, L]
+
+            # Build conditioning from features
+            metadata = (_load_metadata(npz_path, device, dtype),)
+            conditioning = diffusion.conditioner(metadata, device)
+            cond_inputs = diffusion.get_conditioning_inputs(conditioning)
+
+        # Rectified flow: interpolate between data and noise
+        t = torch.rand(x0.shape[0], device=device, dtype=dtype)  # [1]
+        noise = torch.randn_like(x0)
+        # t expanded for broadcast: [1] -> [1, 1, 1]
+        t_bcast = t[:, None, None]
+        x_t = (1.0 - t_bcast) * x0 + t_bcast * noise
+        v_target = noise - x0
+
+        with torch.amp.autocast(device_type=device.type, dtype=dtype):
+            v_pred = dit(x_t, t,
+                         cfg_scale=1.0,
+                         cfg_dropout_prob=cfg_dropout_prob,
+                         **cond_inputs)
+
+        loss = F.mse_loss(v_pred.float(), v_target.float())
+
+        # Clear gradients on every step, on both the scaler and the plain path.
+        optimizer.zero_grad(set_to_none=True)
+        if use_scaler:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()
+
+        if step % 50 == 0:
+            print(f"[PrismAudio] step {step}/{train_steps} loss={loss.item():.6f}", flush=True)
+
+        if step % save_every == 0:
+            ckpt_path = f"{output_stem}_step{step}.safetensors"
+            _save_lora(dit, ckpt_path, lora_config, save_file)
+            print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True)
+
+        pbar.update(1)
+
+    _save_lora(dit, output_path, lora_config, save_file)
+    print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
+
+
+# ---------------------------------------------------------------------------
+# Trainer node
+# ---------------------------------------------------------------------------
+
+class PrismAudioLoRATrainer:
+    """ComfyUI node that trains LoRA adapters on the PrismAudio DiT and saves them."""
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "model": ("PRISMAUDIO_MODEL",),
+                "dataset_dir": ("STRING", {"default": "", "tooltip": "Directory containing paired .npz feature files and .wav/.flac audio files (matched by filename stem)"}),
+                "output_path": ("STRING", {"default": "", "tooltip": "Save path for .safetensors weights. Empty = models/prismaudio/lora/"}),
+                "lora_rank": ("INT", {"default": 64, "min": 1, "max": 512}),
+                "lora_alpha": ("FLOAT", {"default": 64.0, "min": 1.0, "max": 1024.0}),
+                "target_modules": (["attn_ffn", "attn_only", "full"], {"tooltip": "attn_only: Q/K/V/out only. attn_ffn: + FFN input (recommended). full: + transformer I/O projections"}),
+                "learning_rate": ("FLOAT", {"default": 1e-4, "min": 1e-7, "max": 1e-2, "step": 1e-6}),
+                "train_steps": ("INT", {"default": 1000, "min": 1, "max": 100000}),
+                "cfg_dropout_prob": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 0.5, "step": 0.01, "tooltip": "Probability of dropping conditioning per step — preserves CFG ability at inference"}),
+                "save_every": ("INT", {"default": 500, "min": 1, "max": 100000, "tooltip": "Save a checkpoint every N steps (in addition to final save)"}),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
+            },
+        }
+
+    RETURN_TYPES = ("STRING",)
+    RETURN_NAMES = ("lora_path",)
+    FUNCTION = "train"
+    CATEGORY = PRISMAUDIO_CATEGORY
+
+    def train(self, model, dataset_dir, output_path, lora_rank, lora_alpha,
+              target_modules, learning_rate, train_steps, cfg_dropout_prob, save_every, seed):
+        """Train LoRA adapters and return the path of the saved weights.
+
+        The model is returned to its un-wrapped base state (and offloaded when
+        the model's strategy is ``offload_to_cpu``) even if training fails.
+        """
+        from safetensors.torch import save_file
+
+        device = get_device()
+        dtype = model["dtype"]
+        diffusion = model["model"]
+        strategy = model["strategy"]
+
+        torch.manual_seed(seed)
+        random.seed(seed)
+
+        # Scan dataset
+        if not os.path.isdir(dataset_dir):
+            raise FileNotFoundError(f"[PrismAudio] Dataset directory not found: {dataset_dir}")
+        pairs = _scan_dataset(dataset_dir)
+        if not pairs:
+            raise RuntimeError(f"[PrismAudio] No (.npz + audio) pairs found in: {dataset_dir}")
+        print(f"[PrismAudio] LoRA training — {len(pairs)} sample(s), {train_steps} steps", flush=True)
+
+        # Resolve output path
+        if not output_path:
+            import folder_paths
+            out_dir = os.path.join(folder_paths.models_dir, "prismaudio", "lora")
+            output_path = os.path.join(out_dir, f"prismaudio_lora_r{lora_rank}.safetensors")
+        elif not output_path.endswith(".safetensors"):
+            # Checkpoint and config names are derived from this suffix; without
+            # it they would collide with (and overwrite) the weights file.
+            output_path += ".safetensors"
+        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
+
+        dit = diffusion.model  # DiffusionTransformer
+        target_attrs = _TARGET_MODULE_PRESETS[target_modules]
+        lora_config = {
+            "rank": lora_rank,
+            "alpha": lora_alpha,
+            "target_modules": sorted(target_attrs),
+        }
+
+        # Loss scaling is only meaningful for fp16 on CUDA.
+        use_scaler = dtype == torch.float16 and device.type == "cuda"
+        lora_dtype = torch.float32 if use_scaler else dtype
+
+        try:
+            # NOTE(review): ComfyUI is understood to run node functions under
+            # torch.inference_mode(), where backward() cannot work; confirm.
+            # Base weights created under inference mode may additionally need
+            # cloning before they can take part in autograd.
+            with torch.inference_mode(False), torch.enable_grad():
+                # Move model to device
+                diffusion.model.to(device)
+                diffusion.conditioner.to(device)
+                diffusion.pretransform.to(device)
+
+                # Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B)
+                for p in dit.parameters():
+                    p.requires_grad_(False)
+                _apply_lora(dit, target_attrs, lora_rank, lora_alpha)
+
+                # With GradScaler active the LoRA params must be float32 (it
+                # refuses to unscale fp16 gradients); autocast casts them for
+                # the forward pass. Otherwise they match the model dtype.
+                for m in dit.modules():
+                    if isinstance(m, LoRALinear):
+                        m.lora_A.to(device=device, dtype=lora_dtype)
+                        m.lora_B.to(device=device, dtype=lora_dtype)
+
+                _fit_lora(diffusion, dit, pairs, device, dtype, use_scaler, learning_rate, train_steps,
+                          cfg_dropout_prob, save_every, output_path, lora_config, save_file)
+        finally:
+            # Restore model to base state (remove LoRA wrappers, restore original linears)
+            dit.eval()
+            _unapply_lora(dit)
+
+            if strategy == "offload_to_cpu":
+                diffusion.model.to(get_offload_device())
+                diffusion.conditioner.to(get_offload_device())
+                diffusion.pretransform.to(get_offload_device())
+                soft_empty_cache()
+
+        return (output_path,)
diff --git a/requirements.txt b/requirements.txt
index 3b237d9..2c6fedc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,4 @@ descript-audio-codec
 vector-quantize-pytorch
 scipy
 tqdm
+torchaudio