feat: LoRA fine-tuning for SelVA generator
Teaches the model new/partial sound classes from custom video+audio pairs.
Only ~10 MB of adapter weights are trained vs ~4.4 GB for the full model.
selva_core/model/lora.py
LoRALinear: wraps nn.Linear with frozen base + trainable A/B matrices.
B initialised to zero → zero adapter contribution at init.
apply_lora(): walks named_modules, replaces matching nn.Linear in-place.
Default target: "attn.qkv" (all 21 SelfAttention QKV projections in
large_44k). Add "linear1" to also wrap post-attention output projections.
get_lora_state_dict() / load_lora() for ~10 MB save/load.
train_lora.py (standalone script, no ComfyUI dependency)
Data format: directory of video files + optional prompts.txt
("filename: description"). Falls back to directory name as prompt.
Pre-extracts features for all clips into RAM, then trains from those.
Training loop: encode audio→latent (need_vae_encoder=True), flow
matching MSE loss on velocity prediction, backward on LoRA params only.
Saves adapter_stepNNNNN.pt checkpoints + adapter_final.pt with metadata.
Key verified interfaces used:
encode_audio() → DiagonalGaussianDistribution; .mode().clone() required
normalize() is in-place
forward(latent, clip_f, sync_f, text_f, t) takes raw tensors
nodes/selva_lora_loader.py (SelVA LoRA Loader ComfyUI node)
Loads .pt adapter, deep-copies the generator, applies LoRA, loads weights.
strength param scales lora_B to adjust adapter contribution at inference.
Reads rank/alpha/target from embedded metadata if present.
Returns a patched SELVA_MODEL bundle for use with the existing Sampler.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,93 @@
|
||||
import copy
|
||||
import torch
|
||||
import folder_paths
|
||||
|
||||
from .utils import SELVA_CATEGORY
|
||||
from selva_core.model.lora import apply_lora, load_lora
|
||||
|
||||
|
||||
class SelvaLoraLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("SELVA_MODEL",),
|
||||
"adapter_path": ("STRING", {
|
||||
"default": "",
|
||||
"tooltip": "Path to a LoRA adapter .pt file produced by train_lora.py.",
|
||||
}),
|
||||
"strength": ("FLOAT", {
|
||||
"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05,
|
||||
"tooltip": "Scale applied to all LoRA contributions. "
|
||||
"1.0 = full adapter strength. "
|
||||
"0.0 = effectively disables the adapter. "
|
||||
"Values above 1.0 exaggerate the effect.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("SELVA_MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
OUTPUT_TOOLTIPS = ("Model with LoRA adapter applied — connect to Sampler.",)
|
||||
FUNCTION = "load"
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
DESCRIPTION = (
|
||||
"Loads a LoRA adapter produced by train_lora.py and applies it to the generator. "
|
||||
"The base model is not modified — a shallow copy of the model bundle is returned."
|
||||
)
|
||||
|
||||
def load(self, model: dict, adapter_path: str, strength: float) -> tuple:
|
||||
if not adapter_path.strip():
|
||||
raise ValueError("[SelVA LoRA] adapter_path is empty.")
|
||||
|
||||
# Resolve path: allow absolute or relative to ComfyUI base
|
||||
from pathlib import Path
|
||||
p = Path(adapter_path)
|
||||
if not p.is_absolute():
|
||||
p = Path(folder_paths.base_path) / p
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"[SelVA LoRA] Adapter not found: {p}")
|
||||
|
||||
checkpoint = torch.load(str(p), map_location="cpu", weights_only=False)
|
||||
|
||||
# Support both raw state_dict and {state_dict, meta} formats
|
||||
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
|
||||
state_dict = checkpoint["state_dict"]
|
||||
meta = checkpoint.get("meta", {})
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
meta = {}
|
||||
|
||||
rank = int(meta.get("rank", 16))
|
||||
alpha = float(meta.get("alpha", float(rank)))
|
||||
target = list(meta.get("target", ["attn.qkv"]))
|
||||
|
||||
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
|
||||
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} strength={strength}",
|
||||
flush=True)
|
||||
|
||||
# Shallow-copy the model bundle so the original generator is not mutated
|
||||
patched = {**model}
|
||||
generator = copy.deepcopy(model["generator"])
|
||||
|
||||
n = apply_lora(generator, rank=rank, alpha=alpha, target_suffixes=tuple(target))
|
||||
if n == 0:
|
||||
raise RuntimeError(
|
||||
f"[SelVA LoRA] No layers matched target={target}. "
|
||||
"Check that the adapter was trained with the same target suffixes."
|
||||
)
|
||||
load_lora(generator, state_dict)
|
||||
|
||||
# Apply strength scaling: multiply all lora_B params by strength
|
||||
# (lora_B is initialised to zero, so scaling A is equivalent but less clean)
|
||||
if strength != 1.0:
|
||||
with torch.no_grad():
|
||||
for name, param in generator.named_parameters():
|
||||
if "lora_B" in name:
|
||||
param.mul_(strength)
|
||||
|
||||
generator.to(model["generator"].parameters().__next__().device)
|
||||
patched["generator"] = generator
|
||||
|
||||
print(f"[SelVA LoRA] Applied {n} LoRA layers.", flush=True)
|
||||
return (patched,)
|
||||
Reference in New Issue
Block a user