From 437c62b28ff49f942e9a50686a3a10191a73be92 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 5 Apr 2026 14:38:46 +0200 Subject: [PATCH] feat: LoRA fine-tuning for SelVA generator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Teaches the model new/partial sound classes from custom video+audio pairs. Only ~10 MB of adapter weights are trained vs ~4.4 GB for the full model. selva_core/model/lora.py LoRALinear: wraps nn.Linear with frozen base + trainable A/B matrices. B initialised to zero → zero adapter contribution at init. apply_lora(): walks named_modules, replaces matching nn.Linear in-place. Default target: "attn.qkv" (all 21 SelfAttention QKV projections in large_44k). Add "linear1" to also wrap post-attention output projections. get_lora_state_dict() / load_lora() for ~10 MB save/load. train_lora.py (standalone script, no ComfyUI dependency) Data format: directory of video files + optional prompts.txt ("filename: description"). Falls back to directory name as prompt. Pre-extracts features for all clips into RAM, then trains from those. Training loop: encode audio→latent (need_vae_encoder=True), flow matching MSE loss on velocity prediction, backward on LoRA params only. Saves adapter_stepNNNNN.pt checkpoints + adapter_final.pt with metadata. Key verified interfaces used: encode_audio() → DiagonalGaussianDistribution; .mode().clone() required normalize() is in-place forward(latent, clip_f, sync_f, text_f, t) takes raw tensors nodes/selva_lora_loader.py (SelVA LoRA Loader ComfyUI node) Loads .pt adapter, deep-copies the generator, applies LoRA, loads weights. strength param scales lora_B to adjust adapter contribution at inference. Reads rank/alpha/target from embedded metadata if present. Returns a patched SELVA_MODEL bundle for use with the existing Sampler. Co-Authored-By: Claude Sonnet 4.6 --- nodes/__init__.py | 1 + nodes/selva_lora_loader.py | 93 ++++++++ selva_core/model/lora.py | 116 ++++++++++ train_lora.py | 419 +++++++++++++++++++++++++++++++++++++ 4 files changed, 629 insertions(+) create mode 100644 nodes/selva_lora_loader.py create mode 100644 selva_core/model/lora.py create mode 100644 train_lora.py diff --git a/nodes/__init__.py b/nodes/__init__.py index ed3b2b9..9ba2b0c 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -5,6 +5,7 @@ _NODES = { "SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"), "SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"), "SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"), + "SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"), } for key, (module_path, class_name, display_name) in _NODES.items(): diff --git a/nodes/selva_lora_loader.py b/nodes/selva_lora_loader.py new file mode 100644 index 0000000..e0fb54d --- /dev/null +++ b/nodes/selva_lora_loader.py @@ -0,0 +1,93 @@ +import copy +import torch +import folder_paths + +from .utils import SELVA_CATEGORY +from selva_core.model.lora import apply_lora, load_lora + + +class SelvaLoraLoader: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("SELVA_MODEL",), + "adapter_path": ("STRING", { + "default": "", + "tooltip": "Path to a LoRA adapter .pt file produced by train_lora.py.", + }), + "strength": ("FLOAT", { + "default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05, + "tooltip": "Scale applied to all LoRA contributions. " + "1.0 = full adapter strength. " + "0.0 = effectively disables the adapter. " + "Values above 1.0 exaggerate the effect.", + }), + }, + } + + RETURN_TYPES = ("SELVA_MODEL",) + RETURN_NAMES = ("model",) + OUTPUT_TOOLTIPS = ("Model with LoRA adapter applied — connect to Sampler.",) + FUNCTION = "load" + CATEGORY = SELVA_CATEGORY + DESCRIPTION = ( + "Loads a LoRA adapter produced by train_lora.py and applies it to the generator. " + "The base model is not modified — a shallow copy of the model bundle is returned." + ) + + def load(self, model: dict, adapter_path: str, strength: float) -> tuple: + if not adapter_path.strip(): + raise ValueError("[SelVA LoRA] adapter_path is empty.") + + # Resolve path: allow absolute or relative to ComfyUI base + from pathlib import Path + p = Path(adapter_path) + if not p.is_absolute(): + p = Path(folder_paths.base_path) / p + if not p.exists(): + raise FileNotFoundError(f"[SelVA LoRA] Adapter not found: {p}") + + checkpoint = torch.load(str(p), map_location="cpu", weights_only=False) + + # Support both raw state_dict and {state_dict, meta} formats + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + meta = checkpoint.get("meta", {}) + else: + state_dict = checkpoint + meta = {} + + rank = int(meta.get("rank", 16)) + alpha = float(meta.get("alpha", float(rank))) + target = list(meta.get("target", ["attn.qkv"])) + + print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True) + print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} strength={strength}", + flush=True) + + # Shallow-copy the model bundle so the original generator is not mutated + patched = {**model} + generator = copy.deepcopy(model["generator"]) + + n = apply_lora(generator, rank=rank, alpha=alpha, target_suffixes=tuple(target)) + if n == 0: + raise RuntimeError( + f"[SelVA LoRA] No layers matched target={target}. " + "Check that the adapter was trained with the same target suffixes." + ) + load_lora(generator, state_dict) + + # Apply strength scaling: multiply all lora_B params by strength + # (lora_B is initialised to zero, so scaling A is equivalent but less clean) + if strength != 1.0: + with torch.no_grad(): + for name, param in generator.named_parameters(): + if "lora_B" in name: + param.mul_(strength) + + generator.to(model["generator"].parameters().__next__().device) + patched["generator"] = generator + + print(f"[SelVA LoRA] Applied {n} LoRA layers.", flush=True) + return (patched,) diff --git a/selva_core/model/lora.py b/selva_core/model/lora.py new file mode 100644 index 0000000..bbf0132 --- /dev/null +++ b/selva_core/model/lora.py @@ -0,0 +1,116 @@ +""" +LoRA (Low-Rank Adaptation) for SelVA / MMAudio generator. + +Usage: + from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora + + n = apply_lora(net_generator, rank=16, alpha=16.0) + print(f"Wrapped {n} linear layers with LoRA") + + # ... train only LoRA params ... + + torch.save(get_lora_state_dict(net_generator), "adapter.pt") + + # Later, at inference: + apply_lora(net_generator, rank=16, alpha=16.0) + load_lora(net_generator, torch.load("adapter.pt")) +""" + +import math + +import torch +import torch.nn as nn + + +class LoRALinear(nn.Module): + """nn.Linear with a frozen base weight and trainable low-rank A/B matrices. + + Output: base(x) + (x @ A.T @ B.T) * (alpha / rank) + + A is initialised with Kaiming uniform; B is initialised to zero so the + adapter contribution starts at zero and does not disturb pretrained behaviour. + """ + + def __init__(self, linear: nn.Linear, rank: int, alpha: float): + super().__init__() + in_f = linear.in_features + out_f = linear.out_features + + self.linear = linear + linear.weight.requires_grad_(False) + if linear.bias is not None: + linear.bias.requires_grad_(False) + + self.lora_A = nn.Parameter(torch.empty(rank, in_f)) + self.lora_B = nn.Parameter(torch.zeros(out_f, rank)) + self.scale = alpha / rank + + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale + + def extra_repr(self) -> str: + rank = self.lora_A.shape[0] + return (f"in={self.linear.in_features}, out={self.linear.out_features}, " + f"rank={rank}, scale={self.scale:.4f}") + + +def apply_lora( + model: nn.Module, + rank: int = 16, + alpha: float = None, + target_suffixes: tuple = ("attn.qkv",), +) -> int: + """Replace matching nn.Linear layers with LoRALinear in-place. + + Args: + model: The module to modify (typically net_generator). + rank: LoRA rank. + alpha: LoRA alpha (scaling). Defaults to rank (scale = 1.0). + target_suffixes: Tuple of module name suffixes to wrap. Default is + ("attn.qkv",) which targets all SelfAttention QKV + projections in the MM-DiT generator. + Add "linear1" to also wrap post-attention output projections. + + Returns: + Number of linear layers wrapped. + """ + if alpha is None: + alpha = float(rank) + + count = 0 + for name, module in list(model.named_modules()): + if not any(name.endswith(s) for s in target_suffixes): + continue + if not isinstance(module, nn.Linear): + continue + + parts = name.split(".") + parent = model + for part in parts[:-1]: + parent = getattr(parent, part) + setattr(parent, parts[-1], LoRALinear(module, rank, alpha)) + count += 1 + + return count + + +def get_lora_state_dict(model: nn.Module) -> dict: + """Return a state dict containing only LoRA parameters (lora_A and lora_B).""" + return {k: v for k, v in model.state_dict().items() if "lora_" in k} + + +def load_lora(model: nn.Module, state_dict: dict) -> None: + """Load LoRA weights into a model that has already had apply_lora() called. + + Non-LoRA keys in state_dict are ignored (strict=False). Non-LoRA model + parameters are not modified. + """ + missing, unexpected = model.load_state_dict(state_dict, strict=False) + bad = [k for k in unexpected if "lora_" not in k] + if bad: + print(f"[LoRA] Warning: unexpected non-LoRA keys ignored: {bad}") + lora_missing = [k for k in missing if "lora_" in k] + if lora_missing: + print(f"[LoRA] Warning: missing LoRA keys (wrong rank/target?): {lora_missing}") diff --git a/train_lora.py b/train_lora.py new file mode 100644 index 0000000..bc3ab4b --- /dev/null +++ b/train_lora.py @@ -0,0 +1,419 @@ +#!/usr/bin/env python3 +""" +LoRA fine-tuning for SelVA / MMAudio generator. + +Teaches the model new or partially-known sound classes from custom video+audio pairs. +Only the LoRA adapter weights are trained (~10 MB vs ~4.4 GB for the full model). + +Data layout: + data/my_sound/ + clip01.mp4 # video files — audio is extracted from the video track + clip02.mp4 + prompts.txt # optional: "clip01.mp4: description of the sound" + +If prompts.txt is absent, the directory name is used as the prompt for all clips. + +Usage: + python train_lora.py \\ + --data_dir data/my_sound \\ + --output_dir lora_output \\ + --variant large_44k \\ + --selva_dir /path/to/ComfyUI/models/selva \\ + --rank 16 --steps 2000 --lr 1e-4 +""" + +import argparse +import os +import sys +import math +import random +import json +from pathlib import Path + +import torch +import torch.nn.functional as F +import torchaudio +from torchvision.io import read_video + +sys.path.insert(0, os.path.dirname(__file__)) + +from selva_core.model.networks_generator import get_my_mmaudio +from selva_core.model.networks_video_enc import get_my_textsynch +from selva_core.model.utils.features_utils import FeaturesUtils +from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K +from selva_core.model.flow_matching import FlowMatching +from selva_core.model.lora import apply_lora, get_lora_state_dict + +# --------------------------------------------------------------------------- +# Constants (mirror selva_feature_extractor.py) +# --------------------------------------------------------------------------- + +_CLIP_SIZE = 384 +_SYNC_SIZE = 224 +_CLIP_FPS = 8 +_SYNC_FPS = 25 + +_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1) +_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1) + +_VARIANTS = { + "small_16k": ("generator_small_16k_sup_5.pth", "16k", True), + "small_44k": ("generator_small_44k_sup_5.pth", "44k", False), + "medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False), + "large_44k": ("generator_large_44k_sup_5.pth", "44k", False), +} + +_VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv"} + +# --------------------------------------------------------------------------- +# Data helpers +# --------------------------------------------------------------------------- + +def load_prompts(data_dir: Path) -> dict: + """Load filename → prompt from prompts.txt. Returns empty dict if absent.""" + p = data_dir / "prompts.txt" + if not p.exists(): + return {} + mapping = {} + for line in p.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if ":" in line: + fname, prompt = line.split(":", 1) + mapping[fname.strip()] = prompt.strip() + return mapping + + +def load_clip(path: Path, target_sr: int, duration: float): + """Load a video file. + + Returns: + video: [T, H, W, C] float32 [0, 1] + audio: [L] float32 [-1, 1], resampled and trimmed/padded to duration + source_fps: float + """ + video, audio, info = read_video(str(path), pts_unit="sec", output_format="THWC") + + source_fps = float(info.get("video_fps", 30.0)) + audio_fps = int(info.get("audio_fps", target_sr)) + + # Video → float32 [0, 1] + video = video.float() / 255.0 # [T, H, W, C] + + # Audio → mono float32 [-1, 1] + target_len = int(duration * target_sr) + if audio.numel() == 0: + audio_out = torch.zeros(target_len) + else: + # audio shape: (channels, samples) — torchvision returns float in [-1, 1] + if audio.dim() == 2: + audio = audio.mean(0) # stereo → mono + elif audio.dim() == 1: + pass + audio = audio.float() + + # Safety: clamp to [-1, 1] in case of PCM encoding + if audio.abs().max() > 1.0: + audio = audio / 32768.0 + + if audio_fps != target_sr: + audio = torchaudio.functional.resample( + audio.unsqueeze(0), audio_fps, target_sr + ).squeeze(0) + + if audio.shape[0] >= target_len: + audio_out = audio[:target_len] + else: + audio_out = F.pad(audio, (0, target_len - audio.shape[0])) + + return video, audio_out, source_fps + + +def _sample_frames(video, source_fps, target_fps, duration): + T = video.shape[0] + n_out = max(1, int(duration * target_fps)) + indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)] + return video[indices] + + +def _resize_frames(frames, size): + x = frames.permute(0, 3, 1, 2).float() # [N, C, H, W] + x = F.interpolate(x, size=(size, size), mode="bicubic", align_corners=False) + return x.clamp(0.0, 1.0) + + +def extract_features(video, audio, source_fps, prompt, duration, + feature_utils, net_video_enc, device, dtype): + """Extract all conditioning features from a single video+audio clip. + + All returned tensors are on CPU, detached — ready to move to device for training. + """ + with torch.no_grad(): + # --- Audio latent (VAE encode) --- + # encode_audio is @inference_mode and returns DiagonalGaussianDistribution + audio_b = audio.unsqueeze(0).to(feature_utils.device, dtype) # [1, L] + dist = feature_utils.encode_audio(audio_b) + x1 = dist.mode().clone().cpu() # [1, seq_len, latent_dim] — .clone() exits inference mode + + # --- CLIP visual features --- + clip_frames = _sample_frames(video, source_fps, _CLIP_FPS, duration) + clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384] + clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384] + clip_f = feature_utils.encode_video_with_clip(clip_input).cpu() # [1, N, 1024] + + # --- Sync (TextSynchformer) features --- + sync_frames = _sample_frames(video, source_fps, _SYNC_FPS, duration) + sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224] + if sync_frames.shape[0] < 16: + pad = 16 - sync_frames.shape[0] + sync_frames = torch.cat( + [sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0) + mean = _SYNC_MEAN.to(sync_frames.device) + std = _SYNC_STD.to(sync_frames.device) + sync_frames = (sync_frames - mean) / std + sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224] + + text_t5, text_mask = feature_utils.encode_text_t5([prompt]) + text_t5, text_mask = net_video_enc.prepend_sup_text_tokens(text_t5, text_mask) + sync_f = net_video_enc.encode_video_with_sync( + sync_input, text_f=text_t5, text_mask=text_mask + ).cpu() # [1, T_sync, 768] + + # --- CLIP text features --- + text_clip = feature_utils.encode_text_clip([prompt]).cpu() # [1, 77, D] + + return x1, clip_f, sync_f, text_clip + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator") + parser.add_argument("--data_dir", required=True, help="Directory with video files and optional prompts.txt") + parser.add_argument("--output_dir", default="lora_output") + parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys())) + parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)") + parser.add_argument("--rank", type=int, default=16, help="LoRA rank") + parser.add_argument("--alpha", type=float, default=None, help="LoRA alpha (default: rank)") + parser.add_argument("--target", nargs="+", default=["attn.qkv"], + help="Module name suffixes to wrap with LoRA. Also try 'linear1'.") + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--steps", type=int, default=2000) + parser.add_argument("--warmup_steps",type=int, default=500) + parser.add_argument("--grad_accum", type=int, default=4, help="Gradient accumulation steps") + parser.add_argument("--save_every", type=int, default=500) + parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"]) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + torch.manual_seed(args.seed) + random.seed(args.seed) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if args.precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported(): + print("[LoRA] bf16 not supported on this GPU — falling back to fp16") + args.precision = "fp16" + dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.precision] + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + selva_dir = Path(args.selva_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + gen_filename, mode, has_bigvgan = _VARIANTS[args.variant] + seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K + duration = seq_cfg.duration + sample_rate = seq_cfg.sampling_rate + + # --- Weight paths --- + def w(name): return str(selva_dir / name) + def wext(name): return str(selva_dir / "ext" / name) + + for path, label in [ + (w("video_enc_sup_5.pth"), "video_enc"), + (w(gen_filename), "generator"), + (wext("v1-16.pth" if mode == "16k" else "v1-44.pth"), "VAE"), + ]: + if not Path(path).exists(): + print(f"[LoRA] Missing weight: {path} ({label})") + print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.") + sys.exit(1) + + synch_path = str(selva_dir / "synchformer_state_dict.pth") + if not Path(synch_path).exists(): + # Fallback: check prismaudio dir + alt = selva_dir.parent / "prismaudio" / "synchformer_state_dict.pth" + if alt.exists(): + synch_path = str(alt) + else: + print(f"[LoRA] Missing synchformer weights: {synch_path}") + sys.exit(1) + + bigvgan_path = wext("best_netG.pt") if has_bigvgan else None + + # --- Load models --- + print(f"[LoRA] Loading TextSynch encoder...") + net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval() + net_video_enc.load_weights( + torch.load(w("video_enc_sup_5.pth"), map_location="cpu", weights_only=False) + ) + + print(f"[LoRA] Loading generator ({args.variant})...") + net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval() + net_generator.load_weights( + torch.load(w(gen_filename), map_location="cpu", weights_only=False) + ) + + print("[LoRA] Loading FeaturesUtils (need_vae_encoder=True)...") + feature_utils = FeaturesUtils( + tod_vae_ckpt=wext("v1-16.pth" if mode == "16k" else "v1-44.pth"), + synchformer_ckpt=synch_path, + enable_conditions=True, + mode=mode, + bigvgan_vocoder_ckpt=bigvgan_path, + need_vae_encoder=True, # required for audio → latent encoding during training + ).to(device, dtype).eval() + + # --- Apply LoRA --- + n_lora = apply_lora( + net_generator, + rank=args.rank, + alpha=args.alpha, + target_suffixes=tuple(args.target), + ) + print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target})") + if n_lora == 0: + print("[LoRA] ERROR: no layers were wrapped — check --target names.") + sys.exit(1) + + # Freeze everything except LoRA params + for name, p in net_generator.named_parameters(): + p.requires_grad_("lora_" in name) + + trainable = sum(p.numel() for p in net_generator.parameters() if p.requires_grad) + total = sum(p.numel() for p in net_generator.parameters()) + print(f"[LoRA] Trainable: {trainable:,} / {total:,} params " + f"({100 * trainable / total:.2f}%)") + + # Update rotary position embeddings for the fixed sequence lengths + net_generator.update_seq_lengths( + latent_seq_len=seq_cfg.latent_seq_len, + clip_seq_len=seq_cfg.clip_seq_len, + sync_seq_len=seq_cfg.sync_seq_len, + ) + + # --- Dataset --- + video_files = sorted( + p for p in data_dir.iterdir() + if p.suffix.lower() in _VIDEO_EXTS + ) + if not video_files: + print(f"[LoRA] No video files found in {data_dir}") + sys.exit(1) + print(f"[LoRA] Found {len(video_files)} video(s) in {data_dir}") + + prompt_map = load_prompts(data_dir) + default_prompt = data_dir.name # use directory name as fallback prompt + + # Pre-extract features for all clips (cache in RAM) + print("[LoRA] Extracting features from all clips...") + dataset = [] + for vf in video_files: + prompt = prompt_map.get(vf.name, default_prompt) + print(f" {vf.name}: '{prompt}'") + try: + video, audio, source_fps = load_clip(vf, sample_rate, duration) + x1, clip_f, sync_f, text_clip = extract_features( + video, audio, source_fps, prompt, duration, + feature_utils, net_video_enc, device, dtype, + ) + dataset.append((x1, clip_f, sync_f, text_clip)) + except Exception as e: + print(f" [LoRA] Warning: failed to process {vf.name}: {e}") + if not dataset: + print("[LoRA] No clips could be loaded.") + sys.exit(1) + print(f"[LoRA] {len(dataset)} clips ready.") + + # --- Optimizer + LR scheduler --- + lora_params = [p for p in net_generator.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW(lora_params, lr=args.lr, weight_decay=1e-2) + + def lr_lambda(step): + if step < args.warmup_steps: + return step / max(1, args.warmup_steps) + return 1.0 + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25) + + # --- Training loop --- + net_generator.train() + optimizer.zero_grad() + + print(f"\n[LoRA] Training: {args.steps} steps, lr={args.lr}, grad_accum={args.grad_accum}") + print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n") + + total_loss = 0.0 + for step in range(1, args.steps + 1): + # Sample a random clip from the dataset + x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) + + x1 = x1_cpu.to(device, dtype) + clip_f = clip_f_cpu.to(device, dtype) + sync_f = sync_f_cpu.to(device, dtype) + text_clip = text_clip_cpu.to(device, dtype) + + # Normalize latent in-place (net_generator.normalize is in-place) + net_generator.normalize(x1) + + # Flow matching step + t = torch.rand(1, device=device, dtype=dtype) # (1,) — one timestep + x0 = torch.randn_like(x1) + xt = fm.get_conditional_flow(x0, x1, t) + + # Forward pass — gradients flow through LoRA A/B only + # forward(latent, clip_f, sync_f, text_f, t) takes raw feature tensors + v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t) + + loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum + loss.backward() + total_loss += loss.item() * args.grad_accum + + if step % args.grad_accum == 0: + torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + if step % 50 == 0: + avg = total_loss / 50 + lr_now = scheduler.get_last_lr()[0] + print(f"[LoRA] step {step:5d}/{args.steps} loss={avg:.4f} lr={lr_now:.2e}") + total_loss = 0.0 + + if step % args.save_every == 0 or step == args.steps: + ckpt = output_dir / f"adapter_step{step:05d}.pt" + torch.save(get_lora_state_dict(net_generator), ckpt) + print(f"[LoRA] Saved {ckpt}") + + # Save final adapter with metadata + final = output_dir / "adapter_final.pt" + meta = { + "variant": args.variant, + "rank": args.rank, + "alpha": args.alpha if args.alpha is not None else float(args.rank), + "target": args.target, + "steps": args.steps, + } + torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final) + (output_dir / "meta.json").write_text(json.dumps(meta, indent=2)) + print(f"\n[LoRA] Training complete. Adapter saved to {final}") + + +if __name__ == "__main__": + main()