refactor: train_lora accepts .npz + audio pairs instead of raw video

- Input is now a set of pre-extracted .npz files (from SelvaFeatureExtractor),
  each paired with a clean audio file sharing the same filename stem. Visual
  features are no longer re-extracted during training.
- FeaturesUtils loaded with enable_conditions=False (VAE only) — Synchformer
  and T5 are no longer loaded, saving ~3-4 GB VRAM.
- CLIP text encoder is loaded separately via patch_clip so that the text prompt
  can differ from the one used during feature extraction.
- Prompt priority: prompts.txt override > embedded in .npz > directory name.
- Removed: torchvision video loading, frame sampling/resizing, net_video_enc,
  synchformer path check.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 15:14:26 +02:00
parent cde280049b
commit 1eb82d8050
+115 -171
View File
@@ -7,11 +7,12 @@ Only the LoRA adapter weights are trained (~10 MB vs ~4.4 GB for the full model)
Data layout:
data/my_sound/
clip01.mp4 # video files — audio is extracted from the video track
clip02.mp4
prompts.txt # optional: "clip01.mp4: description of the sound"
clip01.npz # visual features extracted by SelvaFeatureExtractor in ComfyUI
clip01.wav # paired clean audio (same filename stem, any format)
prompts.txt # optional: "clip01.npz: description" — overrides embedded prompt
If prompts.txt is absent, the directory name is used as the prompt for all clips.
If prompts.txt is absent, the prompt embedded in each .npz is used.
If the .npz has no embedded prompt, the directory name is used as fallback.
Usage:
python train_lora.py \\
@@ -29,47 +30,40 @@ import random
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from torchvision.io import read_video
import open_clip
from open_clip import create_model_from_pretrained
sys.path.insert(0, os.path.dirname(__file__))
from selva_core.model.networks_generator import get_my_mmaudio
from selva_core.model.networks_video_enc import get_my_textsynch
from selva_core.model.utils.features_utils import FeaturesUtils
from selva_core.model.utils.features_utils import FeaturesUtils, patch_clip
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.lora import apply_lora, get_lora_state_dict
# ---------------------------------------------------------------------------
# Constants (mirror selva_feature_extractor.py)
# Constants
# ---------------------------------------------------------------------------
_CLIP_SIZE = 384
_SYNC_SIZE = 224
_CLIP_FPS = 8
_SYNC_FPS = 25
_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
_VARIANTS = {
"small_16k": ("generator_small_16k_sup_5.pth", "16k", True),
"small_44k": ("generator_small_44k_sup_5.pth", "44k", False),
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False),
"large_44k": ("generator_large_44k_sup_5.pth", "44k", False),
"small_16k": ("generator_small_16k_sup_5.pth", "16k"),
"small_44k": ("generator_small_44k_sup_5.pth", "44k"),
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k"),
"large_44k": ("generator_large_44k_sup_5.pth", "44k"),
}
_VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv"}
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
# ---------------------------------------------------------------------------
# Data helpers
# ---------------------------------------------------------------------------
def load_prompts(data_dir: Path) -> dict:
"""Load filename → prompt from prompts.txt. Returns empty dict if absent."""
"""Load filename → prompt overrides from prompts.txt."""
p = data_dir / "prompts.txt"
if not p.exists():
return {}
@@ -84,105 +78,68 @@ def load_prompts(data_dir: Path) -> dict:
return mapping
def load_clip(path: Path, target_sr: int, duration: float):
"""Load a video file.
def find_audio_for_npz(npz_path: Path) -> Path | None:
"""Find a paired audio file with the same stem as the .npz."""
for ext in _AUDIO_EXTS:
candidate = npz_path.with_suffix(ext)
if candidate.exists():
return candidate
return None
Returns:
video: [T, H, W, C] float32 [0, 1]
audio: [L] float32 [-1, 1], resampled and trimmed/padded to duration
source_fps: float
"""
video, audio, info = read_video(str(path), pts_unit="sec", output_format="THWC")
source_fps = float(info.get("video_fps", 30.0))
audio_fps = int(info.get("audio_fps", target_sr))
def load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
"""Load an audio file → [L] float32 [-1, 1], resampled and trimmed/padded to duration."""
waveform, sr = torchaudio.load(str(path))
# Video → float32 [0, 1]
video = video.float() / 255.0 # [T, H, W, C]
# Stereo → mono
if waveform.shape[0] > 1:
waveform = waveform.mean(0, keepdim=True)
waveform = waveform.squeeze(0).float()
# Audio → mono float32 [-1, 1]
target_len = int(duration * target_sr)
if audio.numel() == 0:
audio_out = torch.zeros(target_len)
else:
# audio shape: (channels, samples) — torchvision returns float in [-1, 1]
if audio.dim() == 2:
audio = audio.mean(0) # stereo → mono
elif audio.dim() == 1:
pass
audio = audio.float()
# Safety: clamp to [-1, 1] in case of PCM encoding
if audio.abs().max() > 1.0:
audio = audio / 32768.0
if audio_fps != target_sr:
audio = torchaudio.functional.resample(
audio.unsqueeze(0), audio_fps, target_sr
# Resample
if sr != target_sr:
waveform = torchaudio.functional.resample(
waveform.unsqueeze(0), sr, target_sr
).squeeze(0)
if audio.shape[0] >= target_len:
audio_out = audio[:target_len]
else:
audio_out = F.pad(audio, (0, target_len - audio.shape[0]))
return video, audio_out, source_fps
target_len = int(duration * target_sr)
if waveform.shape[0] >= target_len:
return waveform[:target_len]
return F.pad(waveform, (0, target_len - waveform.shape[0]))
def _sample_frames(video, source_fps, target_fps, duration):
T = video.shape[0]
n_out = max(1, int(duration * target_fps))
indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)]
return video[indices]
def load_npz(path: Path) -> dict:
    """Load a feature bundle produced by SelvaFeatureExtractor.

    Args:
        path: Path to the ``.npz`` archive.

    Returns:
        Dict with ``clip_features`` and ``sync_features`` as tensors, plus
        ``prompt`` and/or ``variant`` as strings when the archive embeds them.

    Raises:
        KeyError: If ``clip_features`` or ``sync_features`` is missing.
    """
    # Shapes as annotated by the extractor side: clip_features [1, N, 1024],
    # sync_features [1, T, 768] — not validated here.
    required = ("clip_features", "sync_features")
    optional = ("prompt", "variant")

    # np.load on an .npz returns a lazy NpzFile that keeps the archive open;
    # the context manager closes it once the arrays have been read into memory.
    with np.load(str(path), allow_pickle=False) as data:
        missing = [key for key in required if key not in data]
        if missing:
            raise KeyError(
                f"{Path(path).name} is missing required array(s): {', '.join(missing)}"
            )
        bundle = {key: torch.from_numpy(data[key]) for key in required}
        for key in optional:
            if key in data:
                bundle[key] = str(data[key])
    return bundle
def _resize_frames(frames, size):
x = frames.permute(0, 3, 1, 2).float() # [N, C, H, W]
x = F.interpolate(x, size=(size, size), mode="bicubic", align_corners=False)
return x.clamp(0.0, 1.0)
# ---------------------------------------------------------------------------
# Feature extraction (audio + text only — visual features come from .npz)
# ---------------------------------------------------------------------------
def encode_text_clip(clip_model, tokenizer, text: list[str], device) -> torch.Tensor:
    """Encode prompts with the CLIP text encoder.

    Args:
        clip_model: Model exposing ``encode_text(tokens, normalize=True)``.
        tokenizer: Callable mapping a list of strings to a token tensor.
        text: Prompts to encode.
        device: Device the tokens are moved to before encoding.

    Returns:
        The text features as a regular (non-inference) tensor, so they can be
        fed to a model whose forward pass records autograd state.
    """
    tokens = tokenizer(text).to(device)
    with torch.inference_mode():
        features = clip_model.encode_text(tokens, normalize=True)
    # Tensors created under inference_mode cannot be saved for backward.
    # Cloning outside the context yields a normal tensor, which matters when a
    # later .cpu()/.to() is a no-op (e.g. CPU runs) and would otherwise pass
    # the inference tensor straight into training.
    return features.clone()
def extract_features(video, audio, source_fps, prompt, duration,
feature_utils, net_video_enc, device, dtype):
"""Extract all conditioning features from a single video+audio clip.
def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> torch.Tensor:
"""Encode a waveform to the generator's latent space via the VAE.
All returned tensors are on CPU, detached — ready to move to device for training.
encode_audio is @inference_mode — .clone() is required before the autograd path.
"""
with torch.no_grad():
# --- Audio latent (VAE encode) ---
# encode_audio is @inference_mode and returns DiagonalGaussianDistribution
audio_b = audio.unsqueeze(0).to(feature_utils.device, dtype) # [1, L]
audio_b = audio.unsqueeze(0).to(device, dtype) # [1, L]
dist = feature_utils.encode_audio(audio_b)
x1 = dist.mode().clone().cpu() # [1, seq_len, latent_dim] — .clone() exits inference mode
# --- CLIP visual features ---
clip_frames = _sample_frames(video, source_fps, _CLIP_FPS, duration)
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
clip_f = feature_utils.encode_video_with_clip(clip_input).cpu() # [1, N, 1024]
# --- Sync (TextSynchformer) features ---
sync_frames = _sample_frames(video, source_fps, _SYNC_FPS, duration)
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
if sync_frames.shape[0] < 16:
pad = 16 - sync_frames.shape[0]
sync_frames = torch.cat(
[sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
mean = _SYNC_MEAN.to(sync_frames.device)
std = _SYNC_STD.to(sync_frames.device)
sync_frames = (sync_frames - mean) / std
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
text_t5, text_mask = feature_utils.encode_text_t5([prompt])
text_t5, text_mask = net_video_enc.prepend_sup_text_tokens(text_t5, text_mask)
sync_f = net_video_enc.encode_video_with_sync(
sync_input, text_f=text_t5, text_mask=text_mask
).cpu() # [1, T_sync, 768]
# --- CLIP text features ---
text_clip = feature_utils.encode_text_clip([prompt]).cpu() # [1, 77, D]
return x1, clip_f, sync_f, text_clip
return dist.mode().clone().cpu() # [1, seq_len, latent_dim]
# ---------------------------------------------------------------------------
@@ -191,7 +148,7 @@ def extract_features(video, audio, source_fps, prompt, duration,
def main():
parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator")
parser.add_argument("--data_dir", required=True, help="Directory with video files and optional prompts.txt")
parser.add_argument("--data_dir", required=True, help="Directory with .npz + audio pairs and optional prompts.txt")
parser.add_argument("--output_dir", default="lora_output")
parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys()))
parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)")
@@ -222,7 +179,7 @@ def main():
selva_dir = Path(args.selva_dir)
output_dir.mkdir(parents=True, exist_ok=True)
gen_filename, mode, has_bigvgan = _VARIANTS[args.variant]
gen_filename, mode = _VARIANTS[args.variant]
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
duration = seq_cfg.duration
sample_rate = seq_cfg.sampling_rate
@@ -231,51 +188,38 @@ def main():
def w(name): return str(selva_dir / name)
def wext(name): return str(selva_dir / "ext" / name)
for path, label in [
(w("video_enc_sup_5.pth"), "video_enc"),
(w(gen_filename), "generator"),
(wext("v1-16.pth" if mode == "16k" else "v1-44.pth"), "VAE"),
]:
vae_weight = wext("v1-16.pth" if mode == "16k" else "v1-44.pth")
gen_weight = w(gen_filename)
for path, label in [(vae_weight, "VAE"), (gen_weight, "generator")]:
if not Path(path).exists():
print(f"[LoRA] Missing weight: {path} ({label})")
print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.")
sys.exit(1)
synch_path = str(selva_dir / "synchformer_state_dict.pth")
if not Path(synch_path).exists():
# Fallback: check prismaudio dir
alt = selva_dir.parent / "prismaudio" / "synchformer_state_dict.pth"
if alt.exists():
synch_path = str(alt)
else:
print(f"[LoRA] Missing synchformer weights: {synch_path}")
sys.exit(1)
# --- Load CLIP text encoder (separate from FeaturesUtils to avoid loading Synchformer/T5) ---
print("[LoRA] Loading CLIP text encoder...")
clip_model = create_model_from_pretrained(
'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False
).to(device, dtype).eval()
clip_model = patch_clip(clip_model)
tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')
bigvgan_path = wext("best_netG.pt") if has_bigvgan else None
# --- Load models ---
print(f"[LoRA] Loading TextSynch encoder...")
net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
net_video_enc.load_weights(
torch.load(w("video_enc_sup_5.pth"), map_location="cpu", weights_only=False)
)
# --- Load VAE (FeaturesUtils with enable_conditions=False — no Synchformer/T5) ---
print("[LoRA] Loading VAE encoder...")
feature_utils = FeaturesUtils(
tod_vae_ckpt=vae_weight,
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
).to(device, dtype).eval()
# --- Load generator ---
print(f"[LoRA] Loading generator ({args.variant})...")
net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval()
net_generator.load_weights(
torch.load(w(gen_filename), map_location="cpu", weights_only=False)
torch.load(gen_weight, map_location="cpu", weights_only=False)
)
print("[LoRA] Loading FeaturesUtils (need_vae_encoder=True)...")
feature_utils = FeaturesUtils(
tod_vae_ckpt=wext("v1-16.pth" if mode == "16k" else "v1-44.pth"),
synchformer_ckpt=synch_path,
enable_conditions=True,
mode=mode,
bigvgan_vocoder_ckpt=bigvgan_path,
need_vae_encoder=True, # required for audio → latent encoding during training
).to(device, dtype).eval()
# --- Apply LoRA ---
n_lora = apply_lora(
net_generator,
@@ -297,7 +241,6 @@ def main():
print(f"[LoRA] Trainable: {trainable:,} / {total:,} params "
f"({100 * trainable / total:.2f}%)")
# Update rotary position embeddings for the fixed sequence lengths
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
@@ -305,37 +248,44 @@ def main():
)
# --- Dataset ---
video_files = sorted(
p for p in data_dir.iterdir()
if p.suffix.lower() in _VIDEO_EXTS
)
if not video_files:
print(f"[LoRA] No video files found in {data_dir}")
npz_files = sorted(data_dir.glob("*.npz"))
if not npz_files:
print(f"[LoRA] No .npz files found in {data_dir}")
sys.exit(1)
print(f"[LoRA] Found {len(video_files)} video(s) in {data_dir}")
prompt_map = load_prompts(data_dir)
default_prompt = data_dir.name # use directory name as fallback prompt
default_prompt = data_dir.name
# Pre-extract features for all clips (cache in RAM)
print("[LoRA] Extracting features from all clips...")
print(f"[LoRA] Pre-loading {len(npz_files)} clip(s)...")
dataset = []
for vf in video_files:
prompt = prompt_map.get(vf.name, default_prompt)
print(f" {vf.name}: '{prompt}'")
for npz_path in npz_files:
audio_path = find_audio_for_npz(npz_path)
if audio_path is None:
print(f" [LoRA] Warning: no audio file found for {npz_path.name} — skipping")
continue
# Prompt priority: prompts.txt override > embedded in .npz > directory name
prompt = prompt_map.get(npz_path.name)
if prompt is None:
bundle = load_npz(npz_path)
prompt = bundle.get("prompt", default_prompt)
else:
bundle = load_npz(npz_path)
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'")
try:
video, audio, source_fps = load_clip(vf, sample_rate, duration)
x1, clip_f, sync_f, text_clip = extract_features(
video, audio, source_fps, prompt, duration,
feature_utils, net_video_enc, device, dtype,
)
dataset.append((x1, clip_f, sync_f, text_clip))
audio = load_audio(audio_path, sample_rate, duration)
x1 = extract_audio_latent(audio, feature_utils, device, dtype)
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
except Exception as e:
print(f" [LoRA] Warning: failed to process {vf.name}: {e}")
print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")
if not dataset:
print("[LoRA] No clips could be loaded.")
sys.exit(1)
print(f"[LoRA] {len(dataset)} clips ready.")
print(f"[LoRA] {len(dataset)} clip(s) ready.")
# --- Optimizer + LR scheduler ---
lora_params = [p for p in net_generator.parameters() if p.requires_grad]
@@ -347,7 +297,6 @@ def main():
return 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
# --- Training loop ---
@@ -359,7 +308,6 @@ def main():
total_loss = 0.0
for step in range(1, args.steps + 1):
# Sample a random clip from the dataset
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
x1 = x1_cpu.to(device, dtype)
@@ -367,16 +315,12 @@ def main():
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype)
# Normalize latent in-place (net_generator.normalize is in-place)
net_generator.normalize(x1)
# Flow matching step
t = torch.rand(1, device=device, dtype=dtype) # (1,) — one timestep
t = torch.rand(1, device=device, dtype=dtype)
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)
# Forward pass — gradients flow through LoRA A/B only
# forward(latent, clip_f, sync_f, text_f, t) takes raw feature tensors
v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t)
loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum
@@ -400,7 +344,7 @@ def main():
torch.save(get_lora_state_dict(net_generator), ckpt)
print(f"[LoRA] Saved {ckpt}")
# Save final adapter with metadata
# Save final adapter with embedded metadata
final = output_dir / "adapter_final.pt"
meta = {
"variant": args.variant,