feat: LoRA fine-tuning for SelVA generator
Teaches the model new/partial sound classes from custom video+audio pairs.
Only ~10 MB of adapter weights are trained vs ~4.4 GB for the full model.
selva_core/model/lora.py
LoRALinear: wraps nn.Linear with frozen base + trainable A/B matrices.
B initialised to zero → zero adapter contribution at init.
apply_lora(): walks named_modules, replaces matching nn.Linear in-place.
Default target: "attn.qkv" (all 21 SelfAttention QKV projections in
large_44k). Add "linear1" to also wrap post-attention output projections.
get_lora_state_dict() / load_lora() for ~10 MB save/load.
train_lora.py (standalone script, no ComfyUI dependency)
Data format: directory of video files + optional prompts.txt
("filename: description"). Falls back to directory name as prompt.
Pre-extracts features for all clips into RAM, then trains from those.
Training loop: encode audio→latent (need_vae_encoder=True), flow
matching MSE loss on velocity prediction, backward on LoRA params only.
Saves adapter_stepNNNNN.pt checkpoints + adapter_final.pt with metadata.
Key verified interfaces used:
encode_audio() → DiagonalGaussianDistribution; .mode().clone() required
normalize() is in-place
forward(latent, clip_f, sync_f, text_f, t) takes raw tensors
nodes/selva_lora_loader.py (SelVA LoRA Loader ComfyUI node)
Loads .pt adapter, deep-copies the generator, applies LoRA, loads weights.
strength param scales lora_B to adjust adapter contribution at inference.
Reads rank/alpha/target from embedded metadata if present.
Returns a patched SELVA_MODEL bundle for use with the existing Sampler.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
LoRA (Low-Rank Adaptation) for SelVA / MMAudio generator.
|
||||
|
||||
Usage:
|
||||
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
|
||||
|
||||
n = apply_lora(net_generator, rank=16, alpha=16.0)
|
||||
print(f"Wrapped {n} linear layers with LoRA")
|
||||
|
||||
# ... train only LoRA params ...
|
||||
|
||||
torch.save(get_lora_state_dict(net_generator), "adapter.pt")
|
||||
|
||||
# Later, at inference:
|
||||
apply_lora(net_generator, rank=16, alpha=16.0)
|
||||
load_lora(net_generator, torch.load("adapter.pt"))
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
|
||||
|
||||
Output: base(x) + (x @ A.T @ B.T) * (alpha / rank)
|
||||
|
||||
A is initialised with Kaiming uniform; B is initialised to zero so the
|
||||
adapter contribution starts at zero and does not disturb pretrained behaviour.
|
||||
"""
|
||||
|
||||
def __init__(self, linear: nn.Linear, rank: int, alpha: float):
|
||||
super().__init__()
|
||||
in_f = linear.in_features
|
||||
out_f = linear.out_features
|
||||
|
||||
self.linear = linear
|
||||
linear.weight.requires_grad_(False)
|
||||
if linear.bias is not None:
|
||||
linear.bias.requires_grad_(False)
|
||||
|
||||
self.lora_A = nn.Parameter(torch.empty(rank, in_f))
|
||||
self.lora_B = nn.Parameter(torch.zeros(out_f, rank))
|
||||
self.scale = alpha / rank
|
||||
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
rank = self.lora_A.shape[0]
|
||||
return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
|
||||
f"rank={rank}, scale={self.scale:.4f}")
|
||||
|
||||
|
||||
def apply_lora(
|
||||
model: nn.Module,
|
||||
rank: int = 16,
|
||||
alpha: float = None,
|
||||
target_suffixes: tuple = ("attn.qkv",),
|
||||
) -> int:
|
||||
"""Replace matching nn.Linear layers with LoRALinear in-place.
|
||||
|
||||
Args:
|
||||
model: The module to modify (typically net_generator).
|
||||
rank: LoRA rank.
|
||||
alpha: LoRA alpha (scaling). Defaults to rank (scale = 1.0).
|
||||
target_suffixes: Tuple of module name suffixes to wrap. Default is
|
||||
("attn.qkv",) which targets all SelfAttention QKV
|
||||
projections in the MM-DiT generator.
|
||||
Add "linear1" to also wrap post-attention output projections.
|
||||
|
||||
Returns:
|
||||
Number of linear layers wrapped.
|
||||
"""
|
||||
if alpha is None:
|
||||
alpha = float(rank)
|
||||
|
||||
count = 0
|
||||
for name, module in list(model.named_modules()):
|
||||
if not any(name.endswith(s) for s in target_suffixes):
|
||||
continue
|
||||
if not isinstance(module, nn.Linear):
|
||||
continue
|
||||
|
||||
parts = name.split(".")
|
||||
parent = model
|
||||
for part in parts[:-1]:
|
||||
parent = getattr(parent, part)
|
||||
setattr(parent, parts[-1], LoRALinear(module, rank, alpha))
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def get_lora_state_dict(model: nn.Module) -> dict:
|
||||
"""Return a state dict containing only LoRA parameters (lora_A and lora_B)."""
|
||||
return {k: v for k, v in model.state_dict().items() if "lora_" in k}
|
||||
|
||||
|
||||
def load_lora(model: nn.Module, state_dict: dict) -> None:
|
||||
"""Load LoRA weights into a model that has already had apply_lora() called.
|
||||
|
||||
Non-LoRA keys in state_dict are ignored (strict=False). Non-LoRA model
|
||||
parameters are not modified.
|
||||
"""
|
||||
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||
bad = [k for k in unexpected if "lora_" not in k]
|
||||
if bad:
|
||||
print(f"[LoRA] Warning: unexpected non-LoRA keys ignored: {bad}")
|
||||
lora_missing = [k for k in missing if "lora_" in k]
|
||||
if lora_missing:
|
||||
print(f"[LoRA] Warning: missing LoRA keys (wrong rank/target?): {lora_missing}")
|
||||
Reference in New Issue
Block a user