refactor: train_lora accepts .npz + audio pairs instead of raw video
- Input is now pre-extracted .npz files (from SelvaFeatureExtractor) paired with clean audio files (same stem). Visual features no longer re-extracted during training. - FeaturesUtils loaded with enable_conditions=False (VAE only) — Synchformer and T5 are no longer loaded, saving ~3-4 GB VRAM. - CLIP text encoder loaded separately via patch_clip so text prompt can differ from the one used during feature extraction. - Prompt priority: prompts.txt override > embedded in .npz > directory name. - Removed: torchvision video loading, frame sampling/resizing, net_video_enc, synchformer path check. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+115
-171
@@ -7,11 +7,12 @@ Only the LoRA adapter weights are trained (~10 MB vs ~4.4 GB for the full model)
|
||||
|
||||
Data layout:
|
||||
data/my_sound/
|
||||
clip01.mp4 # video files — audio is extracted from the video track
|
||||
clip02.mp4
|
||||
prompts.txt # optional: "clip01.mp4: description of the sound"
|
||||
clip01.npz # visual features extracted by SelvaFeatureExtractor in ComfyUI
|
||||
clip01.wav # paired clean audio (same filename stem, any format)
|
||||
prompts.txt # optional: "clip01.npz: description" — overrides embedded prompt
|
||||
|
||||
If prompts.txt is absent, the directory name is used as the prompt for all clips.
|
||||
If prompts.txt is absent, the prompt embedded in each .npz is used.
|
||||
If the .npz has no embedded prompt, the directory name is used as fallback.
|
||||
|
||||
Usage:
|
||||
python train_lora.py \\
|
||||
@@ -29,47 +30,40 @@ import random
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from torchvision.io import read_video
|
||||
import open_clip
|
||||
from open_clip import create_model_from_pretrained
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from selva_core.model.networks_generator import get_my_mmaudio
|
||||
from selva_core.model.networks_video_enc import get_my_textsynch
|
||||
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||
from selva_core.model.utils.features_utils import FeaturesUtils, patch_clip
|
||||
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
from selva_core.model.lora import apply_lora, get_lora_state_dict
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants (mirror selva_feature_extractor.py)
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CLIP_SIZE = 384
|
||||
_SYNC_SIZE = 224
|
||||
_CLIP_FPS = 8
|
||||
_SYNC_FPS = 25
|
||||
|
||||
_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
|
||||
_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
|
||||
|
||||
_VARIANTS = {
|
||||
"small_16k": ("generator_small_16k_sup_5.pth", "16k", True),
|
||||
"small_44k": ("generator_small_44k_sup_5.pth", "44k", False),
|
||||
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False),
|
||||
"large_44k": ("generator_large_44k_sup_5.pth", "44k", False),
|
||||
"small_16k": ("generator_small_16k_sup_5.pth", "16k"),
|
||||
"small_44k": ("generator_small_44k_sup_5.pth", "44k"),
|
||||
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k"),
|
||||
"large_44k": ("generator_large_44k_sup_5.pth", "44k"),
|
||||
}
|
||||
|
||||
_VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv"}
|
||||
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_prompts(data_dir: Path) -> dict:
|
||||
"""Load filename → prompt from prompts.txt. Returns empty dict if absent."""
|
||||
"""Load filename → prompt overrides from prompts.txt."""
|
||||
p = data_dir / "prompts.txt"
|
||||
if not p.exists():
|
||||
return {}
|
||||
@@ -84,105 +78,68 @@ def load_prompts(data_dir: Path) -> dict:
|
||||
return mapping
|
||||
|
||||
|
||||
def load_clip(path: Path, target_sr: int, duration: float):
|
||||
"""Load a video file.
|
||||
def find_audio_for_npz(npz_path: Path) -> Path | None:
|
||||
"""Find a paired audio file with the same stem as the .npz."""
|
||||
for ext in _AUDIO_EXTS:
|
||||
candidate = npz_path.with_suffix(ext)
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
Returns:
|
||||
video: [T, H, W, C] float32 [0, 1]
|
||||
audio: [L] float32 [-1, 1], resampled and trimmed/padded to duration
|
||||
source_fps: float
|
||||
"""
|
||||
video, audio, info = read_video(str(path), pts_unit="sec", output_format="THWC")
|
||||
|
||||
source_fps = float(info.get("video_fps", 30.0))
|
||||
audio_fps = int(info.get("audio_fps", target_sr))
|
||||
def load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
|
||||
"""Load an audio file → [L] float32 [-1, 1], resampled and trimmed/padded to duration."""
|
||||
waveform, sr = torchaudio.load(str(path))
|
||||
|
||||
# Video → float32 [0, 1]
|
||||
video = video.float() / 255.0 # [T, H, W, C]
|
||||
# Stereo → mono
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform.mean(0, keepdim=True)
|
||||
waveform = waveform.squeeze(0).float()
|
||||
|
||||
# Audio → mono float32 [-1, 1]
|
||||
target_len = int(duration * target_sr)
|
||||
if audio.numel() == 0:
|
||||
audio_out = torch.zeros(target_len)
|
||||
else:
|
||||
# audio shape: (channels, samples) — torchvision returns float in [-1, 1]
|
||||
if audio.dim() == 2:
|
||||
audio = audio.mean(0) # stereo → mono
|
||||
elif audio.dim() == 1:
|
||||
pass
|
||||
audio = audio.float()
|
||||
|
||||
# Safety: clamp to [-1, 1] in case of PCM encoding
|
||||
if audio.abs().max() > 1.0:
|
||||
audio = audio / 32768.0
|
||||
|
||||
if audio_fps != target_sr:
|
||||
audio = torchaudio.functional.resample(
|
||||
audio.unsqueeze(0), audio_fps, target_sr
|
||||
# Resample
|
||||
if sr != target_sr:
|
||||
waveform = torchaudio.functional.resample(
|
||||
waveform.unsqueeze(0), sr, target_sr
|
||||
).squeeze(0)
|
||||
|
||||
if audio.shape[0] >= target_len:
|
||||
audio_out = audio[:target_len]
|
||||
else:
|
||||
audio_out = F.pad(audio, (0, target_len - audio.shape[0]))
|
||||
|
||||
return video, audio_out, source_fps
|
||||
target_len = int(duration * target_sr)
|
||||
if waveform.shape[0] >= target_len:
|
||||
return waveform[:target_len]
|
||||
return F.pad(waveform, (0, target_len - waveform.shape[0]))
|
||||
|
||||
|
||||
def _sample_frames(video, source_fps, target_fps, duration):
|
||||
T = video.shape[0]
|
||||
n_out = max(1, int(duration * target_fps))
|
||||
indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)]
|
||||
return video[indices]
|
||||
def load_npz(path: Path) -> dict:
|
||||
"""Load a feature bundle produced by SelvaFeatureExtractor."""
|
||||
data = np.load(str(path), allow_pickle=False)
|
||||
bundle = {
|
||||
"clip_features": torch.from_numpy(data["clip_features"]), # [1, N, 1024]
|
||||
"sync_features": torch.from_numpy(data["sync_features"]), # [1, T, 768]
|
||||
}
|
||||
if "prompt" in data:
|
||||
bundle["prompt"] = str(data["prompt"])
|
||||
if "variant" in data:
|
||||
bundle["variant"] = str(data["variant"])
|
||||
return bundle
|
||||
|
||||
|
||||
def _resize_frames(frames, size):
|
||||
x = frames.permute(0, 3, 1, 2).float() # [N, C, H, W]
|
||||
x = F.interpolate(x, size=(size, size), mode="bicubic", align_corners=False)
|
||||
return x.clamp(0.0, 1.0)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature extraction (audio + text only — visual features come from .npz)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def encode_text_clip(clip_model, tokenizer, text: list[str], device) -> torch.Tensor:
|
||||
tokens = tokenizer(text).to(device)
|
||||
with torch.inference_mode():
|
||||
return clip_model.encode_text(tokens, normalize=True)
|
||||
|
||||
|
||||
def extract_features(video, audio, source_fps, prompt, duration,
|
||||
feature_utils, net_video_enc, device, dtype):
|
||||
"""Extract all conditioning features from a single video+audio clip.
|
||||
def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> torch.Tensor:
|
||||
"""Encode a waveform to the generator's latent space via the VAE.
|
||||
|
||||
All returned tensors are on CPU, detached — ready to move to device for training.
|
||||
encode_audio is @inference_mode — .clone() is required before the autograd path.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# --- Audio latent (VAE encode) ---
|
||||
# encode_audio is @inference_mode and returns DiagonalGaussianDistribution
|
||||
audio_b = audio.unsqueeze(0).to(feature_utils.device, dtype) # [1, L]
|
||||
audio_b = audio.unsqueeze(0).to(device, dtype) # [1, L]
|
||||
dist = feature_utils.encode_audio(audio_b)
|
||||
x1 = dist.mode().clone().cpu() # [1, seq_len, latent_dim] — .clone() exits inference mode
|
||||
|
||||
# --- CLIP visual features ---
|
||||
clip_frames = _sample_frames(video, source_fps, _CLIP_FPS, duration)
|
||||
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
||||
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
|
||||
clip_f = feature_utils.encode_video_with_clip(clip_input).cpu() # [1, N, 1024]
|
||||
|
||||
# --- Sync (TextSynchformer) features ---
|
||||
sync_frames = _sample_frames(video, source_fps, _SYNC_FPS, duration)
|
||||
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
||||
if sync_frames.shape[0] < 16:
|
||||
pad = 16 - sync_frames.shape[0]
|
||||
sync_frames = torch.cat(
|
||||
[sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
|
||||
mean = _SYNC_MEAN.to(sync_frames.device)
|
||||
std = _SYNC_STD.to(sync_frames.device)
|
||||
sync_frames = (sync_frames - mean) / std
|
||||
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
|
||||
|
||||
text_t5, text_mask = feature_utils.encode_text_t5([prompt])
|
||||
text_t5, text_mask = net_video_enc.prepend_sup_text_tokens(text_t5, text_mask)
|
||||
sync_f = net_video_enc.encode_video_with_sync(
|
||||
sync_input, text_f=text_t5, text_mask=text_mask
|
||||
).cpu() # [1, T_sync, 768]
|
||||
|
||||
# --- CLIP text features ---
|
||||
text_clip = feature_utils.encode_text_clip([prompt]).cpu() # [1, 77, D]
|
||||
|
||||
return x1, clip_f, sync_f, text_clip
|
||||
return dist.mode().clone().cpu() # [1, seq_len, latent_dim]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -191,7 +148,7 @@ def extract_features(video, audio, source_fps, prompt, duration,
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator")
|
||||
parser.add_argument("--data_dir", required=True, help="Directory with video files and optional prompts.txt")
|
||||
parser.add_argument("--data_dir", required=True, help="Directory with .npz + audio pairs and optional prompts.txt")
|
||||
parser.add_argument("--output_dir", default="lora_output")
|
||||
parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys()))
|
||||
parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)")
|
||||
@@ -222,7 +179,7 @@ def main():
|
||||
selva_dir = Path(args.selva_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
gen_filename, mode, has_bigvgan = _VARIANTS[args.variant]
|
||||
gen_filename, mode = _VARIANTS[args.variant]
|
||||
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
|
||||
duration = seq_cfg.duration
|
||||
sample_rate = seq_cfg.sampling_rate
|
||||
@@ -231,51 +188,38 @@ def main():
|
||||
def w(name): return str(selva_dir / name)
|
||||
def wext(name): return str(selva_dir / "ext" / name)
|
||||
|
||||
for path, label in [
|
||||
(w("video_enc_sup_5.pth"), "video_enc"),
|
||||
(w(gen_filename), "generator"),
|
||||
(wext("v1-16.pth" if mode == "16k" else "v1-44.pth"), "VAE"),
|
||||
]:
|
||||
vae_weight = wext("v1-16.pth" if mode == "16k" else "v1-44.pth")
|
||||
gen_weight = w(gen_filename)
|
||||
for path, label in [(vae_weight, "VAE"), (gen_weight, "generator")]:
|
||||
if not Path(path).exists():
|
||||
print(f"[LoRA] Missing weight: {path} ({label})")
|
||||
print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.")
|
||||
sys.exit(1)
|
||||
|
||||
synch_path = str(selva_dir / "synchformer_state_dict.pth")
|
||||
if not Path(synch_path).exists():
|
||||
# Fallback: check prismaudio dir
|
||||
alt = selva_dir.parent / "prismaudio" / "synchformer_state_dict.pth"
|
||||
if alt.exists():
|
||||
synch_path = str(alt)
|
||||
else:
|
||||
print(f"[LoRA] Missing synchformer weights: {synch_path}")
|
||||
sys.exit(1)
|
||||
# --- Load CLIP text encoder (separate from FeaturesUtils to avoid loading Synchformer/T5) ---
|
||||
print("[LoRA] Loading CLIP text encoder...")
|
||||
clip_model = create_model_from_pretrained(
|
||||
'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False
|
||||
).to(device, dtype).eval()
|
||||
clip_model = patch_clip(clip_model)
|
||||
tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')
|
||||
|
||||
bigvgan_path = wext("best_netG.pt") if has_bigvgan else None
|
||||
|
||||
# --- Load models ---
|
||||
print(f"[LoRA] Loading TextSynch encoder...")
|
||||
net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
|
||||
net_video_enc.load_weights(
|
||||
torch.load(w("video_enc_sup_5.pth"), map_location="cpu", weights_only=False)
|
||||
)
|
||||
# --- Load VAE (FeaturesUtils with enable_conditions=False — no Synchformer/T5) ---
|
||||
print("[LoRA] Loading VAE encoder...")
|
||||
feature_utils = FeaturesUtils(
|
||||
tod_vae_ckpt=vae_weight,
|
||||
enable_conditions=False,
|
||||
mode=mode,
|
||||
need_vae_encoder=True,
|
||||
).to(device, dtype).eval()
|
||||
|
||||
# --- Load generator ---
|
||||
print(f"[LoRA] Loading generator ({args.variant})...")
|
||||
net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval()
|
||||
net_generator.load_weights(
|
||||
torch.load(w(gen_filename), map_location="cpu", weights_only=False)
|
||||
torch.load(gen_weight, map_location="cpu", weights_only=False)
|
||||
)
|
||||
|
||||
print("[LoRA] Loading FeaturesUtils (need_vae_encoder=True)...")
|
||||
feature_utils = FeaturesUtils(
|
||||
tod_vae_ckpt=wext("v1-16.pth" if mode == "16k" else "v1-44.pth"),
|
||||
synchformer_ckpt=synch_path,
|
||||
enable_conditions=True,
|
||||
mode=mode,
|
||||
bigvgan_vocoder_ckpt=bigvgan_path,
|
||||
need_vae_encoder=True, # required for audio → latent encoding during training
|
||||
).to(device, dtype).eval()
|
||||
|
||||
# --- Apply LoRA ---
|
||||
n_lora = apply_lora(
|
||||
net_generator,
|
||||
@@ -297,7 +241,6 @@ def main():
|
||||
print(f"[LoRA] Trainable: {trainable:,} / {total:,} params "
|
||||
f"({100 * trainable / total:.2f}%)")
|
||||
|
||||
# Update rotary position embeddings for the fixed sequence lengths
|
||||
net_generator.update_seq_lengths(
|
||||
latent_seq_len=seq_cfg.latent_seq_len,
|
||||
clip_seq_len=seq_cfg.clip_seq_len,
|
||||
@@ -305,37 +248,44 @@ def main():
|
||||
)
|
||||
|
||||
# --- Dataset ---
|
||||
video_files = sorted(
|
||||
p for p in data_dir.iterdir()
|
||||
if p.suffix.lower() in _VIDEO_EXTS
|
||||
)
|
||||
if not video_files:
|
||||
print(f"[LoRA] No video files found in {data_dir}")
|
||||
npz_files = sorted(data_dir.glob("*.npz"))
|
||||
if not npz_files:
|
||||
print(f"[LoRA] No .npz files found in {data_dir}")
|
||||
sys.exit(1)
|
||||
print(f"[LoRA] Found {len(video_files)} video(s) in {data_dir}")
|
||||
|
||||
prompt_map = load_prompts(data_dir)
|
||||
default_prompt = data_dir.name # use directory name as fallback prompt
|
||||
default_prompt = data_dir.name
|
||||
|
||||
# Pre-extract features for all clips (cache in RAM)
|
||||
print("[LoRA] Extracting features from all clips...")
|
||||
print(f"[LoRA] Pre-loading {len(npz_files)} clip(s)...")
|
||||
dataset = []
|
||||
for vf in video_files:
|
||||
prompt = prompt_map.get(vf.name, default_prompt)
|
||||
print(f" {vf.name}: '{prompt}'")
|
||||
for npz_path in npz_files:
|
||||
audio_path = find_audio_for_npz(npz_path)
|
||||
if audio_path is None:
|
||||
print(f" [LoRA] Warning: no audio file found for {npz_path.name} — skipping")
|
||||
continue
|
||||
|
||||
# Prompt priority: prompts.txt override > embedded in .npz > directory name
|
||||
prompt = prompt_map.get(npz_path.name)
|
||||
if prompt is None:
|
||||
bundle = load_npz(npz_path)
|
||||
prompt = bundle.get("prompt", default_prompt)
|
||||
else:
|
||||
bundle = load_npz(npz_path)
|
||||
|
||||
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'")
|
||||
|
||||
try:
|
||||
video, audio, source_fps = load_clip(vf, sample_rate, duration)
|
||||
x1, clip_f, sync_f, text_clip = extract_features(
|
||||
video, audio, source_fps, prompt, duration,
|
||||
feature_utils, net_video_enc, device, dtype,
|
||||
)
|
||||
dataset.append((x1, clip_f, sync_f, text_clip))
|
||||
audio = load_audio(audio_path, sample_rate, duration)
|
||||
x1 = extract_audio_latent(audio, feature_utils, device, dtype)
|
||||
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
|
||||
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
|
||||
except Exception as e:
|
||||
print(f" [LoRA] Warning: failed to process {vf.name}: {e}")
|
||||
print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")
|
||||
|
||||
if not dataset:
|
||||
print("[LoRA] No clips could be loaded.")
|
||||
sys.exit(1)
|
||||
print(f"[LoRA] {len(dataset)} clips ready.")
|
||||
print(f"[LoRA] {len(dataset)} clip(s) ready.")
|
||||
|
||||
# --- Optimizer + LR scheduler ---
|
||||
lora_params = [p for p in net_generator.parameters() if p.requires_grad]
|
||||
@@ -347,7 +297,6 @@ def main():
|
||||
return 1.0
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
||||
|
||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
||||
|
||||
# --- Training loop ---
|
||||
@@ -359,7 +308,6 @@ def main():
|
||||
|
||||
total_loss = 0.0
|
||||
for step in range(1, args.steps + 1):
|
||||
# Sample a random clip from the dataset
|
||||
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
||||
|
||||
x1 = x1_cpu.to(device, dtype)
|
||||
@@ -367,16 +315,12 @@ def main():
|
||||
sync_f = sync_f_cpu.to(device, dtype)
|
||||
text_clip = text_clip_cpu.to(device, dtype)
|
||||
|
||||
# Normalize latent in-place (net_generator.normalize is in-place)
|
||||
net_generator.normalize(x1)
|
||||
|
||||
# Flow matching step
|
||||
t = torch.rand(1, device=device, dtype=dtype) # (1,) — one timestep
|
||||
t = torch.rand(1, device=device, dtype=dtype)
|
||||
x0 = torch.randn_like(x1)
|
||||
xt = fm.get_conditional_flow(x0, x1, t)
|
||||
|
||||
# Forward pass — gradients flow through LoRA A/B only
|
||||
# forward(latent, clip_f, sync_f, text_f, t) takes raw feature tensors
|
||||
v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t)
|
||||
|
||||
loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum
|
||||
@@ -400,7 +344,7 @@ def main():
|
||||
torch.save(get_lora_state_dict(net_generator), ckpt)
|
||||
print(f"[LoRA] Saved {ckpt}")
|
||||
|
||||
# Save final adapter with metadata
|
||||
# Save final adapter with embedded metadata
|
||||
final = output_dir / "adapter_final.pt"
|
||||
meta = {
|
||||
"variant": args.variant,
|
||||
|
||||
Reference in New Issue
Block a user