refactor: train_lora accepts .npz + audio pairs instead of raw video

- Input is now pre-extracted .npz files (from SelvaFeatureExtractor) paired
  with clean audio files (same stem). Visual features are no longer
  re-extracted during training.
- FeaturesUtils loaded with enable_conditions=False (VAE only) — Synchformer
  and T5 are no longer loaded, saving ~3-4 GB VRAM.
- CLIP text encoder loaded separately via patch_clip so the text prompt can
  differ from the one used during feature extraction.
- Prompt priority: prompts.txt override > embedded in .npz > directory name.
- Removed: torchvision video loading, frame sampling/resizing, net_video_enc,
  synchformer path check.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 15:14:26 +02:00
parent cde280049b
commit 1eb82d8050
+115 -171
View File
@@ -7,11 +7,12 @@ Only the LoRA adapter weights are trained (~10 MB vs ~4.4 GB for the full model)
Data layout: Data layout:
data/my_sound/ data/my_sound/
clip01.mp4 # video files — audio is extracted from the video track clip01.npz # visual features extracted by SelvaFeatureExtractor in ComfyUI
clip02.mp4 clip01.wav # paired clean audio (same filename stem, any format)
prompts.txt # optional: "clip01.mp4: description of the sound" prompts.txt # optional: "clip01.npz: description" — overrides embedded prompt
If prompts.txt is absent, the directory name is used as the prompt for all clips. If prompts.txt is absent or has no entry for a clip, the prompt embedded in that clip's .npz is used.
If the .npz has no embedded prompt, the directory name is used as fallback.
Usage: Usage:
python train_lora.py \\ python train_lora.py \\
@@ -29,47 +30,40 @@ import random
import json import json
from pathlib import Path from pathlib import Path
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
from torchvision.io import read_video import open_clip
from open_clip import create_model_from_pretrained
sys.path.insert(0, os.path.dirname(__file__)) sys.path.insert(0, os.path.dirname(__file__))
from selva_core.model.networks_generator import get_my_mmaudio from selva_core.model.networks_generator import get_my_mmaudio
from selva_core.model.networks_video_enc import get_my_textsynch from selva_core.model.utils.features_utils import FeaturesUtils, patch_clip
from selva_core.model.utils.features_utils import FeaturesUtils
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
from selva_core.model.flow_matching import FlowMatching from selva_core.model.flow_matching import FlowMatching
from selva_core.model.lora import apply_lora, get_lora_state_dict from selva_core.model.lora import apply_lora, get_lora_state_dict
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Constants (mirror selva_feature_extractor.py) # Constants
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
_CLIP_SIZE = 384
_SYNC_SIZE = 224
_CLIP_FPS = 8
_SYNC_FPS = 25
_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
_VARIANTS = { _VARIANTS = {
"small_16k": ("generator_small_16k_sup_5.pth", "16k", True), "small_16k": ("generator_small_16k_sup_5.pth", "16k"),
"small_44k": ("generator_small_44k_sup_5.pth", "44k", False), "small_44k": ("generator_small_44k_sup_5.pth", "44k"),
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False), "medium_44k": ("generator_medium_44k_sup_5.pth", "44k"),
"large_44k": ("generator_large_44k_sup_5.pth", "44k", False), "large_44k": ("generator_large_44k_sup_5.pth", "44k"),
} }
_VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv"} _AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Data helpers # Data helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def load_prompts(data_dir: Path) -> dict: def load_prompts(data_dir: Path) -> dict:
"""Load filename → prompt from prompts.txt. Returns empty dict if absent.""" """Load filename → prompt overrides from prompts.txt."""
p = data_dir / "prompts.txt" p = data_dir / "prompts.txt"
if not p.exists(): if not p.exists():
return {} return {}
@@ -84,105 +78,68 @@ def load_prompts(data_dir: Path) -> dict:
return mapping return mapping
def load_clip(path: Path, target_sr: int, duration: float): def find_audio_for_npz(npz_path: Path) -> Path | None:
"""Load a video file. """Find a paired audio file with the same stem as the .npz."""
for ext in _AUDIO_EXTS:
candidate = npz_path.with_suffix(ext)
if candidate.exists():
return candidate
return None
Returns:
video: [T, H, W, C] float32 [0, 1]
audio: [L] float32 [-1, 1], resampled and trimmed/padded to duration
source_fps: float
"""
video, audio, info = read_video(str(path), pts_unit="sec", output_format="THWC")
source_fps = float(info.get("video_fps", 30.0)) def load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
audio_fps = int(info.get("audio_fps", target_sr)) """Load an audio file → [L] float32 [-1, 1], resampled and trimmed/padded to duration."""
waveform, sr = torchaudio.load(str(path))
# Video → float32 [0, 1] # Stereo → mono
video = video.float() / 255.0 # [T, H, W, C] if waveform.shape[0] > 1:
waveform = waveform.mean(0, keepdim=True)
waveform = waveform.squeeze(0).float()
# Audio → mono float32 [-1, 1] # Resample
target_len = int(duration * target_sr) if sr != target_sr:
if audio.numel() == 0: waveform = torchaudio.functional.resample(
audio_out = torch.zeros(target_len) waveform.unsqueeze(0), sr, target_sr
else:
# audio shape: (channels, samples) — torchvision returns float in [-1, 1]
if audio.dim() == 2:
audio = audio.mean(0) # stereo → mono
elif audio.dim() == 1:
pass
audio = audio.float()
# Safety: clamp to [-1, 1] in case of PCM encoding
if audio.abs().max() > 1.0:
audio = audio / 32768.0
if audio_fps != target_sr:
audio = torchaudio.functional.resample(
audio.unsqueeze(0), audio_fps, target_sr
).squeeze(0) ).squeeze(0)
if audio.shape[0] >= target_len: target_len = int(duration * target_sr)
audio_out = audio[:target_len] if waveform.shape[0] >= target_len:
else: return waveform[:target_len]
audio_out = F.pad(audio, (0, target_len - audio.shape[0])) return F.pad(waveform, (0, target_len - waveform.shape[0]))
return video, audio_out, source_fps
def _sample_frames(video, source_fps, target_fps, duration): def load_npz(path: Path) -> dict:
T = video.shape[0] """Load a feature bundle produced by SelvaFeatureExtractor."""
n_out = max(1, int(duration * target_fps)) data = np.load(str(path), allow_pickle=False)
indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)] bundle = {
return video[indices] "clip_features": torch.from_numpy(data["clip_features"]), # [1, N, 1024]
"sync_features": torch.from_numpy(data["sync_features"]), # [1, T, 768]
}
if "prompt" in data:
bundle["prompt"] = str(data["prompt"])
if "variant" in data:
bundle["variant"] = str(data["variant"])
return bundle
def _resize_frames(frames, size): # ---------------------------------------------------------------------------
x = frames.permute(0, 3, 1, 2).float() # [N, C, H, W] # Feature extraction (audio + text only — visual features come from .npz)
x = F.interpolate(x, size=(size, size), mode="bicubic", align_corners=False) # ---------------------------------------------------------------------------
return x.clamp(0.0, 1.0)
def encode_text_clip(clip_model, tokenizer, text: list[str], device) -> torch.Tensor:
tokens = tokenizer(text).to(device)
with torch.inference_mode():
return clip_model.encode_text(tokens, normalize=True)
def extract_features(video, audio, source_fps, prompt, duration, def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> torch.Tensor:
feature_utils, net_video_enc, device, dtype): """Encode a waveform to the generator's latent space via the VAE.
"""Extract all conditioning features from a single video+audio clip.
All returned tensors are on CPU, detached — ready to move to device for training. encode_audio is @inference_mode — .clone() is required before the autograd path.
""" """
with torch.no_grad(): audio_b = audio.unsqueeze(0).to(device, dtype) # [1, L]
# --- Audio latent (VAE encode) ---
# encode_audio is @inference_mode and returns DiagonalGaussianDistribution
audio_b = audio.unsqueeze(0).to(feature_utils.device, dtype) # [1, L]
dist = feature_utils.encode_audio(audio_b) dist = feature_utils.encode_audio(audio_b)
x1 = dist.mode().clone().cpu() # [1, seq_len, latent_dim] — .clone() exits inference mode return dist.mode().clone().cpu() # [1, seq_len, latent_dim]
# --- CLIP visual features ---
clip_frames = _sample_frames(video, source_fps, _CLIP_FPS, duration)
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
clip_f = feature_utils.encode_video_with_clip(clip_input).cpu() # [1, N, 1024]
# --- Sync (TextSynchformer) features ---
sync_frames = _sample_frames(video, source_fps, _SYNC_FPS, duration)
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
if sync_frames.shape[0] < 16:
pad = 16 - sync_frames.shape[0]
sync_frames = torch.cat(
[sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
mean = _SYNC_MEAN.to(sync_frames.device)
std = _SYNC_STD.to(sync_frames.device)
sync_frames = (sync_frames - mean) / std
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
text_t5, text_mask = feature_utils.encode_text_t5([prompt])
text_t5, text_mask = net_video_enc.prepend_sup_text_tokens(text_t5, text_mask)
sync_f = net_video_enc.encode_video_with_sync(
sync_input, text_f=text_t5, text_mask=text_mask
).cpu() # [1, T_sync, 768]
# --- CLIP text features ---
text_clip = feature_utils.encode_text_clip([prompt]).cpu() # [1, 77, D]
return x1, clip_f, sync_f, text_clip
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -191,7 +148,7 @@ def extract_features(video, audio, source_fps, prompt, duration,
def main(): def main():
parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator") parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator")
parser.add_argument("--data_dir", required=True, help="Directory with video files and optional prompts.txt") parser.add_argument("--data_dir", required=True, help="Directory with .npz + audio pairs and optional prompts.txt")
parser.add_argument("--output_dir", default="lora_output") parser.add_argument("--output_dir", default="lora_output")
parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys())) parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys()))
parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)") parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)")
@@ -222,7 +179,7 @@ def main():
selva_dir = Path(args.selva_dir) selva_dir = Path(args.selva_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
gen_filename, mode, has_bigvgan = _VARIANTS[args.variant] gen_filename, mode = _VARIANTS[args.variant]
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
duration = seq_cfg.duration duration = seq_cfg.duration
sample_rate = seq_cfg.sampling_rate sample_rate = seq_cfg.sampling_rate
@@ -231,51 +188,38 @@ def main():
def w(name): return str(selva_dir / name) def w(name): return str(selva_dir / name)
def wext(name): return str(selva_dir / "ext" / name) def wext(name): return str(selva_dir / "ext" / name)
for path, label in [ vae_weight = wext("v1-16.pth" if mode == "16k" else "v1-44.pth")
(w("video_enc_sup_5.pth"), "video_enc"), gen_weight = w(gen_filename)
(w(gen_filename), "generator"), for path, label in [(vae_weight, "VAE"), (gen_weight, "generator")]:
(wext("v1-16.pth" if mode == "16k" else "v1-44.pth"), "VAE"),
]:
if not Path(path).exists(): if not Path(path).exists():
print(f"[LoRA] Missing weight: {path} ({label})") print(f"[LoRA] Missing weight: {path} ({label})")
print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.") print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.")
sys.exit(1) sys.exit(1)
synch_path = str(selva_dir / "synchformer_state_dict.pth") # --- Load CLIP text encoder (separate from FeaturesUtils to avoid loading Synchformer/T5) ---
if not Path(synch_path).exists(): print("[LoRA] Loading CLIP text encoder...")
# Fallback: check prismaudio dir clip_model = create_model_from_pretrained(
alt = selva_dir.parent / "prismaudio" / "synchformer_state_dict.pth" 'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False
if alt.exists(): ).to(device, dtype).eval()
synch_path = str(alt) clip_model = patch_clip(clip_model)
else: tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')
print(f"[LoRA] Missing synchformer weights: {synch_path}")
sys.exit(1)
bigvgan_path = wext("best_netG.pt") if has_bigvgan else None # --- Load VAE (FeaturesUtils with enable_conditions=False — no Synchformer/T5) ---
print("[LoRA] Loading VAE encoder...")
# --- Load models --- feature_utils = FeaturesUtils(
print(f"[LoRA] Loading TextSynch encoder...") tod_vae_ckpt=vae_weight,
net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval() enable_conditions=False,
net_video_enc.load_weights( mode=mode,
torch.load(w("video_enc_sup_5.pth"), map_location="cpu", weights_only=False) need_vae_encoder=True,
) ).to(device, dtype).eval()
# --- Load generator ---
print(f"[LoRA] Loading generator ({args.variant})...") print(f"[LoRA] Loading generator ({args.variant})...")
net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval() net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval()
net_generator.load_weights( net_generator.load_weights(
torch.load(w(gen_filename), map_location="cpu", weights_only=False) torch.load(gen_weight, map_location="cpu", weights_only=False)
) )
print("[LoRA] Loading FeaturesUtils (need_vae_encoder=True)...")
feature_utils = FeaturesUtils(
tod_vae_ckpt=wext("v1-16.pth" if mode == "16k" else "v1-44.pth"),
synchformer_ckpt=synch_path,
enable_conditions=True,
mode=mode,
bigvgan_vocoder_ckpt=bigvgan_path,
need_vae_encoder=True, # required for audio → latent encoding during training
).to(device, dtype).eval()
# --- Apply LoRA --- # --- Apply LoRA ---
n_lora = apply_lora( n_lora = apply_lora(
net_generator, net_generator,
@@ -297,7 +241,6 @@ def main():
print(f"[LoRA] Trainable: {trainable:,} / {total:,} params " print(f"[LoRA] Trainable: {trainable:,} / {total:,} params "
f"({100 * trainable / total:.2f}%)") f"({100 * trainable / total:.2f}%)")
# Update rotary position embeddings for the fixed sequence lengths
net_generator.update_seq_lengths( net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len, latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len, clip_seq_len=seq_cfg.clip_seq_len,
@@ -305,37 +248,44 @@ def main():
) )
# --- Dataset --- # --- Dataset ---
video_files = sorted( npz_files = sorted(data_dir.glob("*.npz"))
p for p in data_dir.iterdir() if not npz_files:
if p.suffix.lower() in _VIDEO_EXTS print(f"[LoRA] No .npz files found in {data_dir}")
)
if not video_files:
print(f"[LoRA] No video files found in {data_dir}")
sys.exit(1) sys.exit(1)
print(f"[LoRA] Found {len(video_files)} video(s) in {data_dir}")
prompt_map = load_prompts(data_dir) prompt_map = load_prompts(data_dir)
default_prompt = data_dir.name # use directory name as fallback prompt default_prompt = data_dir.name
# Pre-extract features for all clips (cache in RAM) print(f"[LoRA] Pre-loading {len(npz_files)} clip(s)...")
print("[LoRA] Extracting features from all clips...")
dataset = [] dataset = []
for vf in video_files: for npz_path in npz_files:
prompt = prompt_map.get(vf.name, default_prompt) audio_path = find_audio_for_npz(npz_path)
print(f" {vf.name}: '{prompt}'") if audio_path is None:
print(f" [LoRA] Warning: no audio file found for {npz_path.name} — skipping")
continue
# Prompt priority: prompts.txt override > embedded in .npz > directory name
prompt = prompt_map.get(npz_path.name)
if prompt is None:
bundle = load_npz(npz_path)
prompt = bundle.get("prompt", default_prompt)
else:
bundle = load_npz(npz_path)
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'")
try: try:
video, audio, source_fps = load_clip(vf, sample_rate, duration) audio = load_audio(audio_path, sample_rate, duration)
x1, clip_f, sync_f, text_clip = extract_features( x1 = extract_audio_latent(audio, feature_utils, device, dtype)
video, audio, source_fps, prompt, duration, text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
feature_utils, net_video_enc, device, dtype, dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
)
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e: except Exception as e:
print(f" [LoRA] Warning: failed to process {vf.name}: {e}") print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")
if not dataset: if not dataset:
print("[LoRA] No clips could be loaded.") print("[LoRA] No clips could be loaded.")
sys.exit(1) sys.exit(1)
print(f"[LoRA] {len(dataset)} clips ready.") print(f"[LoRA] {len(dataset)} clip(s) ready.")
# --- Optimizer + LR scheduler --- # --- Optimizer + LR scheduler ---
lora_params = [p for p in net_generator.parameters() if p.requires_grad] lora_params = [p for p in net_generator.parameters() if p.requires_grad]
@@ -347,7 +297,6 @@ def main():
return 1.0 return 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25) fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
# --- Training loop --- # --- Training loop ---
@@ -359,7 +308,6 @@ def main():
total_loss = 0.0 total_loss = 0.0
for step in range(1, args.steps + 1): for step in range(1, args.steps + 1):
# Sample a random clip from the dataset
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
x1 = x1_cpu.to(device, dtype) x1 = x1_cpu.to(device, dtype)
@@ -367,16 +315,12 @@ def main():
sync_f = sync_f_cpu.to(device, dtype) sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype) text_clip = text_clip_cpu.to(device, dtype)
# Normalize latent in-place (net_generator.normalize is in-place)
net_generator.normalize(x1) net_generator.normalize(x1)
# Flow matching step t = torch.rand(1, device=device, dtype=dtype)
t = torch.rand(1, device=device, dtype=dtype) # (1,) — one timestep
x0 = torch.randn_like(x1) x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t) xt = fm.get_conditional_flow(x0, x1, t)
# Forward pass — gradients flow through LoRA A/B only
# forward(latent, clip_f, sync_f, text_f, t) takes raw feature tensors
v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t) v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t)
loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum
@@ -400,7 +344,7 @@ def main():
torch.save(get_lora_state_dict(net_generator), ckpt) torch.save(get_lora_state_dict(net_generator), ckpt)
print(f"[LoRA] Saved {ckpt}") print(f"[LoRA] Saved {ckpt}")
# Save final adapter with metadata # Save final adapter with embedded metadata
final = output_dir / "adapter_final.pt" final = output_dir / "adapter_final.pt"
meta = { meta = {
"variant": args.variant, "variant": args.variant,