diff --git a/train_lora.py b/train_lora.py index f11c955..319c401 100644 --- a/train_lora.py +++ b/train_lora.py @@ -7,11 +7,12 @@ Only the LoRA adapter weights are trained (~10 MB vs ~4.4 GB for the full model) Data layout: data/my_sound/ - clip01.mp4 # video files — audio is extracted from the video track - clip02.mp4 - prompts.txt # optional: "clip01.mp4: description of the sound" + clip01.npz # visual features extracted by SelvaFeatureExtractor in ComfyUI + clip01.wav # paired clean audio (same filename stem, any format) + prompts.txt # optional: "clip01.npz: description" — overrides embedded prompt -If prompts.txt is absent, the directory name is used as the prompt for all clips. +If prompts.txt is absent, the prompt embedded in each .npz is used. +If the .npz has no embedded prompt, the directory name is used as fallback. Usage: python train_lora.py \\ @@ -29,47 +30,40 @@ import random import json from pathlib import Path +import numpy as np import torch import torch.nn.functional as F import torchaudio -from torchvision.io import read_video +import open_clip +from open_clip import create_model_from_pretrained sys.path.insert(0, os.path.dirname(__file__)) from selva_core.model.networks_generator import get_my_mmaudio -from selva_core.model.networks_video_enc import get_my_textsynch -from selva_core.model.utils.features_utils import FeaturesUtils +from selva_core.model.utils.features_utils import FeaturesUtils, patch_clip from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K from selva_core.model.flow_matching import FlowMatching from selva_core.model.lora import apply_lora, get_lora_state_dict # --------------------------------------------------------------------------- -# Constants (mirror selva_feature_extractor.py) +# Constants # --------------------------------------------------------------------------- -_CLIP_SIZE = 384 -_SYNC_SIZE = 224 -_CLIP_FPS = 8 -_SYNC_FPS = 25 - -_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1) -_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1) - _VARIANTS = { - "small_16k": ("generator_small_16k_sup_5.pth", "16k", True), - "small_44k": ("generator_small_44k_sup_5.pth", "44k", False), - "medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False), - "large_44k": ("generator_large_44k_sup_5.pth", "44k", False), + "small_16k": ("generator_small_16k_sup_5.pth", "16k"), + "small_44k": ("generator_small_44k_sup_5.pth", "44k"), + "medium_44k": ("generator_medium_44k_sup_5.pth", "44k"), + "large_44k": ("generator_large_44k_sup_5.pth", "44k"), } -_VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv"} +_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"} # --------------------------------------------------------------------------- # Data helpers # --------------------------------------------------------------------------- def load_prompts(data_dir: Path) -> dict: - """Load filename → prompt from prompts.txt. Returns empty dict if absent.""" + """Load filename → prompt overrides from prompts.txt.""" p = data_dir / "prompts.txt" if not p.exists(): return {} @@ -84,105 +78,68 @@ def load_prompts(data_dir: Path) -> dict: return mapping -def load_clip(path: Path, target_sr: int, duration: float): - """Load a video file. +def find_audio_for_npz(npz_path: Path) -> Path | None: + """Find a paired audio file with the same stem as the .npz.""" + for ext in _AUDIO_EXTS: + candidate = npz_path.with_suffix(ext) + if candidate.exists(): + return candidate + return None - Returns: - video: [T, H, W, C] float32 [0, 1] - audio: [L] float32 [-1, 1], resampled and trimmed/padded to duration - source_fps: float - """ - video, audio, info = read_video(str(path), pts_unit="sec", output_format="THWC") - source_fps = float(info.get("video_fps", 30.0)) - audio_fps = int(info.get("audio_fps", target_sr)) +def load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor: + """Load an audio file → [L] float32 [-1, 1], resampled and trimmed/padded to duration.""" + waveform, sr = torchaudio.load(str(path)) - # Video → float32 [0, 1] - video = video.float() / 255.0 # [T, H, W, C] + # Stereo → mono + if waveform.shape[0] > 1: + waveform = waveform.mean(0, keepdim=True) + waveform = waveform.squeeze(0).float() + + # Resample + if sr != target_sr: + waveform = torchaudio.functional.resample( + waveform.unsqueeze(0), sr, target_sr + ).squeeze(0) - # Audio → mono float32 [-1, 1] target_len = int(duration * target_sr) - if audio.numel() == 0: - audio_out = torch.zeros(target_len) - else: - # audio shape: (channels, samples) — torchvision returns float in [-1, 1] - if audio.dim() == 2: - audio = audio.mean(0) # stereo → mono - elif audio.dim() == 1: - pass - audio = audio.float() - - # Safety: clamp to [-1, 1] in case of PCM encoding - if audio.abs().max() > 1.0: - audio = audio / 32768.0 - - if audio_fps != target_sr: - audio = torchaudio.functional.resample( - audio.unsqueeze(0), audio_fps, target_sr - ).squeeze(0) - - if audio.shape[0] >= target_len: - audio_out = audio[:target_len] - else: - audio_out = F.pad(audio, (0, target_len - audio.shape[0])) - - return video, audio_out, source_fps + if waveform.shape[0] >= target_len: + return waveform[:target_len] + return F.pad(waveform, (0, target_len - waveform.shape[0])) -def _sample_frames(video, source_fps, target_fps, duration): - T = video.shape[0] - n_out = max(1, int(duration * target_fps)) - indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)] - return video[indices] +def load_npz(path: Path) -> dict: + """Load a feature bundle produced by SelvaFeatureExtractor.""" + data = np.load(str(path), allow_pickle=False) + bundle = { + "clip_features": torch.from_numpy(data["clip_features"]), # [1, N, 1024] + "sync_features": torch.from_numpy(data["sync_features"]), # [1, T, 768] + } + if "prompt" in data: + bundle["prompt"] = str(data["prompt"]) + if "variant" in data: + bundle["variant"] = str(data["variant"]) + return bundle -def _resize_frames(frames, size): - x = frames.permute(0, 3, 1, 2).float() # [N, C, H, W] - x = F.interpolate(x, size=(size, size), mode="bicubic", align_corners=False) - return x.clamp(0.0, 1.0) +# --------------------------------------------------------------------------- +# Feature extraction (audio + text only — visual features come from .npz) +# --------------------------------------------------------------------------- + +def encode_text_clip(clip_model, tokenizer, text: list[str], device) -> torch.Tensor: + tokens = tokenizer(text).to(device) + with torch.inference_mode(): + return clip_model.encode_text(tokens, normalize=True) -def extract_features(video, audio, source_fps, prompt, duration, - feature_utils, net_video_enc, device, dtype): - """Extract all conditioning features from a single video+audio clip. +def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> torch.Tensor: + """Encode a waveform to the generator's latent space via the VAE. - All returned tensors are on CPU, detached — ready to move to device for training. + encode_audio is @inference_mode — .clone() is required before the autograd path. """ - with torch.no_grad(): - # --- Audio latent (VAE encode) --- - # encode_audio is @inference_mode and returns DiagonalGaussianDistribution - audio_b = audio.unsqueeze(0).to(feature_utils.device, dtype) # [1, L] - dist = feature_utils.encode_audio(audio_b) - x1 = dist.mode().clone().cpu() # [1, seq_len, latent_dim] — .clone() exits inference mode - - # --- CLIP visual features --- - clip_frames = _sample_frames(video, source_fps, _CLIP_FPS, duration) - clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384] - clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384] - clip_f = feature_utils.encode_video_with_clip(clip_input).cpu() # [1, N, 1024] - - # --- Sync (TextSynchformer) features --- - sync_frames = _sample_frames(video, source_fps, _SYNC_FPS, duration) - sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224] - if sync_frames.shape[0] < 16: - pad = 16 - sync_frames.shape[0] - sync_frames = torch.cat( - [sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0) - mean = _SYNC_MEAN.to(sync_frames.device) - std = _SYNC_STD.to(sync_frames.device) - sync_frames = (sync_frames - mean) / std - sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224] - - text_t5, text_mask = feature_utils.encode_text_t5([prompt]) - text_t5, text_mask = net_video_enc.prepend_sup_text_tokens(text_t5, text_mask) - sync_f = net_video_enc.encode_video_with_sync( - sync_input, text_f=text_t5, text_mask=text_mask - ).cpu() # [1, T_sync, 768] - - # --- CLIP text features --- - text_clip = feature_utils.encode_text_clip([prompt]).cpu() # [1, 77, D] - - return x1, clip_f, sync_f, text_clip + audio_b = audio.unsqueeze(0).to(device, dtype) # [1, L] + dist = feature_utils.encode_audio(audio_b) + return dist.mode().clone().cpu() # [1, seq_len, latent_dim] # --------------------------------------------------------------------------- @@ -191,7 +148,7 @@ def extract_features(video, audio, source_fps, prompt, duration, def main(): parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator") - parser.add_argument("--data_dir", required=True, help="Directory with video files and optional prompts.txt") + parser.add_argument("--data_dir", required=True, help="Directory with .npz + audio pairs and optional prompts.txt") parser.add_argument("--output_dir", default="lora_output") parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys())) parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)") @@ -222,60 +179,47 @@ def main(): selva_dir = Path(args.selva_dir) output_dir.mkdir(parents=True, exist_ok=True) - gen_filename, mode, has_bigvgan = _VARIANTS[args.variant] - seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K - duration = seq_cfg.duration + gen_filename, mode = _VARIANTS[args.variant] + seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K + duration = seq_cfg.duration sample_rate = seq_cfg.sampling_rate # --- Weight paths --- def w(name): return str(selva_dir / name) def wext(name): return str(selva_dir / "ext" / name) - for path, label in [ - (w("video_enc_sup_5.pth"), "video_enc"), - (w(gen_filename), "generator"), - (wext("v1-16.pth" if mode == "16k" else "v1-44.pth"), "VAE"), - ]: + vae_weight = wext("v1-16.pth" if mode == "16k" else "v1-44.pth") + gen_weight = w(gen_filename) + for path, label in [(vae_weight, "VAE"), (gen_weight, "generator")]: if not Path(path).exists(): print(f"[LoRA] Missing weight: {path} ({label})") print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.") sys.exit(1) - synch_path = str(selva_dir / "synchformer_state_dict.pth") - if not Path(synch_path).exists(): - # Fallback: check prismaudio dir - alt = selva_dir.parent / "prismaudio" / "synchformer_state_dict.pth" - if alt.exists(): - synch_path = str(alt) - else: - print(f"[LoRA] Missing synchformer weights: {synch_path}") - sys.exit(1) + # --- Load CLIP text encoder (separate from FeaturesUtils to avoid loading Synchformer/T5) --- + print("[LoRA] Loading CLIP text encoder...") + clip_model = create_model_from_pretrained( + 'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False + ).to(device, dtype).eval() + clip_model = patch_clip(clip_model) + tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') - bigvgan_path = wext("best_netG.pt") if has_bigvgan else None - - # --- Load models --- - print(f"[LoRA] Loading TextSynch encoder...") - net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval() - net_video_enc.load_weights( - torch.load(w("video_enc_sup_5.pth"), map_location="cpu", weights_only=False) - ) + # --- Load VAE (FeaturesUtils with enable_conditions=False — no Synchformer/T5) --- + print("[LoRA] Loading VAE encoder...") + feature_utils = FeaturesUtils( + tod_vae_ckpt=vae_weight, + enable_conditions=False, + mode=mode, + need_vae_encoder=True, + ).to(device, dtype).eval() + # --- Load generator --- print(f"[LoRA] Loading generator ({args.variant})...") net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval() net_generator.load_weights( - torch.load(w(gen_filename), map_location="cpu", weights_only=False) + torch.load(gen_weight, map_location="cpu", weights_only=False) ) - print("[LoRA] Loading FeaturesUtils (need_vae_encoder=True)...") - feature_utils = FeaturesUtils( - tod_vae_ckpt=wext("v1-16.pth" if mode == "16k" else "v1-44.pth"), - synchformer_ckpt=synch_path, - enable_conditions=True, - mode=mode, - bigvgan_vocoder_ckpt=bigvgan_path, - need_vae_encoder=True, # required for audio → latent encoding during training - ).to(device, dtype).eval() - # --- Apply LoRA --- n_lora = apply_lora( net_generator, @@ -297,7 +241,6 @@ def main(): print(f"[LoRA] Trainable: {trainable:,} / {total:,} params " f"({100 * trainable / total:.2f}%)") - # Update rotary position embeddings for the fixed sequence lengths net_generator.update_seq_lengths( latent_seq_len=seq_cfg.latent_seq_len, clip_seq_len=seq_cfg.clip_seq_len, @@ -305,41 +248,48 @@ def main(): ) # --- Dataset --- - video_files = sorted( - p for p in data_dir.iterdir() - if p.suffix.lower() in _VIDEO_EXTS - ) - if not video_files: - print(f"[LoRA] No video files found in {data_dir}") + npz_files = sorted(data_dir.glob("*.npz")) + if not npz_files: + print(f"[LoRA] No .npz files found in {data_dir}") sys.exit(1) - print(f"[LoRA] Found {len(video_files)} video(s) in {data_dir}") - prompt_map = load_prompts(data_dir) - default_prompt = data_dir.name # use directory name as fallback prompt + prompt_map = load_prompts(data_dir) + default_prompt = data_dir.name - # Pre-extract features for all clips (cache in RAM) - print("[LoRA] Extracting features from all clips...") + print(f"[LoRA] Pre-loading {len(npz_files)} clip(s)...") dataset = [] - for vf in video_files: - prompt = prompt_map.get(vf.name, default_prompt) - print(f" {vf.name}: '{prompt}'") + for npz_path in npz_files: + audio_path = find_audio_for_npz(npz_path) + if audio_path is None: + print(f" [LoRA] Warning: no audio file found for {npz_path.name} — skipping") + continue + + # Prompt priority: prompts.txt override > embedded in .npz > directory name + prompt = prompt_map.get(npz_path.name) + if prompt is None: + bundle = load_npz(npz_path) + prompt = bundle.get("prompt", default_prompt) + else: + bundle = load_npz(npz_path) + + print(f" {npz_path.name} + {audio_path.name}: '{prompt}'") + try: - video, audio, source_fps = load_clip(vf, sample_rate, duration) - x1, clip_f, sync_f, text_clip = extract_features( - video, audio, source_fps, prompt, duration, - feature_utils, net_video_enc, device, dtype, - ) - dataset.append((x1, clip_f, sync_f, text_clip)) + audio = load_audio(audio_path, sample_rate, duration) + x1 = extract_audio_latent(audio, feature_utils, device, dtype) + text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu() + dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip)) except Exception as e: - print(f" [LoRA] Warning: failed to process {vf.name}: {e}") + print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}") + if not dataset: print("[LoRA] No clips could be loaded.") sys.exit(1) - print(f"[LoRA] {len(dataset)} clips ready.") + print(f"[LoRA] {len(dataset)} clip(s) ready.") # --- Optimizer + LR scheduler --- lora_params = [p for p in net_generator.parameters() if p.requires_grad] - optimizer = torch.optim.AdamW(lora_params, lr=args.lr, weight_decay=1e-2) + optimizer = torch.optim.AdamW(lora_params, lr=args.lr, weight_decay=1e-2) def lr_lambda(step): if step < args.warmup_steps: @@ -347,7 +297,6 @@ def main(): return 1.0 scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) - fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25) # --- Training loop --- @@ -359,24 +308,19 @@ def main(): total_loss = 0.0 for step in range(1, args.steps + 1): - # Sample a random clip from the dataset x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) - x1 = x1_cpu.to(device, dtype) - clip_f = clip_f_cpu.to(device, dtype) - sync_f = sync_f_cpu.to(device, dtype) + x1 = x1_cpu.to(device, dtype) + clip_f = clip_f_cpu.to(device, dtype) + sync_f = sync_f_cpu.to(device, dtype) text_clip = text_clip_cpu.to(device, dtype) - # Normalize latent in-place (net_generator.normalize is in-place) net_generator.normalize(x1) - # Flow matching step - t = torch.rand(1, device=device, dtype=dtype) # (1,) — one timestep + t = torch.rand(1, device=device, dtype=dtype) x0 = torch.randn_like(x1) xt = fm.get_conditional_flow(x0, x1, t) - # Forward pass — gradients flow through LoRA A/B only - # forward(latent, clip_f, sync_f, text_f, t) takes raw feature tensors v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t) loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum @@ -390,7 +334,7 @@ def main(): optimizer.zero_grad() if step % 50 == 0: - avg = total_loss / 50 + avg = total_loss / 50 lr_now = scheduler.get_last_lr()[0] print(f"[LoRA] step {step:5d}/{args.steps} loss={avg:.4f} lr={lr_now:.2e}") total_loss = 0.0 @@ -400,9 +344,9 @@ def main(): torch.save(get_lora_state_dict(net_generator), ckpt) print(f"[LoRA] Saved {ckpt}") - # Save final adapter with metadata + # Save final adapter with embedded metadata final = output_dir / "adapter_final.pt" - meta = { + meta = { "variant": args.variant, "rank": args.rank, "alpha": args.alpha if args.alpha is not None else float(args.rank),