437c62b28f
Teaches the model new/partial sound classes from custom video+audio pairs.
Only ~10 MB of adapter weights are trained vs ~4.4 GB for the full model.
selva_core/model/lora.py
LoRALinear: wraps nn.Linear with frozen base + trainable A/B matrices.
B initialised to zero → zero adapter contribution at init.
apply_lora(): walks named_modules, replaces matching nn.Linear in-place.
Default target: "attn.qkv" (all 21 SelfAttention QKV projections in
large_44k). Add "linear1" to also wrap post-attention output projections.
get_lora_state_dict() / load_lora() for ~10 MB save/load.
train_lora.py (standalone script, no ComfyUI dependency)
Data format: directory of video files + optional prompts.txt
("filename: description"). Falls back to directory name as prompt.
Pre-extracts features for all clips into RAM, then trains from those.
Training loop: encode audio→latent (need_vae_encoder=True), flow
matching MSE loss on velocity prediction, backward on LoRA params only.
Saves adapter_stepNNNNN.pt checkpoints + adapter_final.pt with metadata.
Key verified interfaces used:
encode_audio() → DiagonalGaussianDistribution; .mode().clone() required
normalize() is in-place
forward(latent, clip_f, sync_f, text_f, t) takes raw tensors
nodes/selva_lora_loader.py (SelVA LoRA Loader ComfyUI node)
Loads .pt adapter, deep-copies the generator, applies LoRA, loads weights.
strength param scales lora_B to adjust adapter contribution at inference.
Reads rank/alpha/target from embedded metadata if present.
Returns a patched SELVA_MODEL bundle for use with the existing Sampler.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
117 lines
3.8 KiB
Python
117 lines
3.8 KiB
Python
"""
LoRA (Low-Rank Adaptation) for SelVA / MMAudio generator.

Usage:
    from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora

    n = apply_lora(net_generator, rank=16, alpha=16.0)
    print(f"Wrapped {n} linear layers with LoRA")

    # ... train only LoRA params ...

    torch.save(get_lora_state_dict(net_generator), "adapter.pt")

    # Later, at inference:
    apply_lora(net_generator, rank=16, alpha=16.0)
    load_lora(net_generator, torch.load("adapter.pt"))
"""
|
|
|
|
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class LoRALinear(nn.Module):
    """nn.Linear with a frozen base weight and trainable low-rank A/B matrices.

    Output: base(x) + (x @ A.T @ B.T) * (alpha / rank)

    A is initialised with Kaiming uniform; B is initialised to zero so the
    adapter contribution starts at zero and does not disturb pretrained behaviour.

    Args:
        linear: The layer to wrap. Its weight (and bias, if present) have
            ``requires_grad`` switched off in place; the layer itself is kept
            as ``self.linear``.
        rank: Inner dimension of the low-rank update. Must be positive.
        alpha: Scaling numerator; the adapter output is multiplied by
            ``alpha / rank``.

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, linear: nn.Linear, rank: int, alpha: float):
        super().__init__()
        # Fail early with a clear message instead of a ZeroDivisionError
        # (rank == 0) or an opaque tensor-shape error (rank < 0) further down.
        if rank <= 0:
            raise ValueError(f"LoRA rank must be a positive integer, got {rank}")

        in_f = linear.in_features
        out_f = linear.out_features

        # Freeze the wrapped layer: only the adapter matrices are trainable.
        self.linear = linear
        linear.weight.requires_grad_(False)
        if linear.bias is not None:
            linear.bias.requires_grad_(False)

        # Allocate the adapter next to the base weight so that wrapping a
        # layer which already lives on an accelerator does not produce a
        # device mismatch in forward(). The dtype is left at the default.
        device = linear.weight.device
        self.lora_A = nn.Parameter(torch.empty(rank, in_f, device=device))
        self.lora_B = nn.Parameter(torch.zeros(out_f, rank, device=device))
        self.scale = alpha / rank

        # Only A gets a random init; B stays zero so the initial delta is zero.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen layer's output plus the scaled low-rank update."""
        return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale

    def extra_repr(self) -> str:
        """Summarise the wrapped layer's dimensions, the rank and the scale."""
        rank = self.lora_A.shape[0]
        return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
                f"rank={rank}, scale={self.scale:.4f}")
|
|
|
|
|
|
def apply_lora(
    model: nn.Module,
    rank: int = 16,
    alpha: "float | None" = None,
    target_suffixes: tuple = ("attn.qkv",),
) -> int:
    """Replace matching nn.Linear layers with LoRALinear in-place.

    Args:
        model:           The module to modify (typically net_generator).
        rank:            LoRA rank.
        alpha:           LoRA alpha (scaling). Defaults to rank (scale = 1.0).
        target_suffixes: Module name suffixes to wrap. Default is
                         ("attn.qkv",) which targets all SelfAttention QKV
                         projections in the MM-DiT generator.
                         Add "linear1" to also wrap post-attention output projections.
                         A single string is accepted and treated as one suffix.

    Returns:
        Number of linear layers wrapped.
    """
    if alpha is None:
        alpha = float(rank)

    # A bare string would otherwise be iterated character by character and
    # match every layer whose name ends in any one of those characters.
    if isinstance(target_suffixes, str):
        suffixes = (target_suffixes,)
    else:
        suffixes = tuple(target_suffixes)

    count = 0
    # Snapshot the module list first: wrapping mutates the tree being walked.
    for name, module in list(model.named_modules()):
        if not name.endswith(suffixes):
            continue
        if not isinstance(module, nn.Linear):
            continue

        # Walk down to the direct parent so the child can be swapped on it.
        *parent_path, leaf = name.split(".")
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)
        setattr(parent, leaf, LoRALinear(module, rank, alpha))
        count += 1

    return count
|
|
|
|
|
|
def get_lora_state_dict(model: nn.Module) -> dict:
    """Collect the adapter-only subset of ``model.state_dict()``.

    An entry is kept when its key contains the substring ``"lora_"``, which
    covers the ``lora_A`` / ``lora_B`` parameters. Entries keep the order in
    which ``state_dict()`` yields them.
    """
    adapter_entries = {}
    for key, tensor in model.state_dict().items():
        if "lora_" in key:
            adapter_entries[key] = tensor
    return adapter_entries
|
|
|
|
|
|
def load_lora(model: nn.Module, state_dict: dict) -> None:
    """Load LoRA weights into a model that has already had apply_lora() called.

    Only entries whose key contains ``"lora_"`` are passed to
    ``model.load_state_dict`` (with strict=False). Every other entry is
    dropped before loading, so non-LoRA model parameters are never modified,
    even when ``state_dict`` is a full checkpoint whose keys match the model.

    Warnings are printed for:
      * non-LoRA keys that were ignored,
      * LoRA keys in ``state_dict`` that the model has no slot for,
      * LoRA parameters of the model that ``state_dict`` did not provide.

    Args:
        model:      Module to receive the adapter weights.
        state_dict: Mapping of parameter names to tensors.

    Note:
        A shape mismatch (e.g. an adapter saved with a different rank) is not
        covered by strict=False and surfaces as the RuntimeError raised by
        ``load_state_dict``.
    """
    lora_weights = {}
    ignored = []
    for key, value in state_dict.items():
        if isinstance(key, str) and "lora_" in key:
            lora_weights[key] = value
        else:
            ignored.append(key)
    if ignored:
        print(f"[LoRA] Warning: unexpected non-LoRA keys ignored: {ignored}")

    missing, unexpected = model.load_state_dict(lora_weights, strict=False)

    # Only LoRA keys were passed in, so anything unexpected is an adapter
    # weight that has no matching wrapped layer in this model.
    if unexpected:
        print(f"[LoRA] Warning: LoRA keys not present in model (wrong target?): {list(unexpected)}")

    # Base weights are always "missing" from an adapter file; only report
    # adapter parameters that received no value.
    lora_missing = [k for k in missing if "lora_" in k]
    if lora_missing:
        print(f"[LoRA] Warning: missing LoRA keys (wrong rank/target?): {lora_missing}")
|