65dc549494
New _reference_metrics() computes LSD, MCD, and per-band correlation between eval samples and the original source audio at each checkpoint. Loads reference audio once before the training loop and logs metrics alongside existing spectral metrics. Also fix batch_size in lora_optimized_dataset.json (4 -> 16). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1094 lines
49 KiB
Python
1094 lines
49 KiB
Python
import copy
|
||
import json
|
||
import math
|
||
import random
|
||
import traceback
|
||
from pathlib import Path
|
||
|
||
|
||
class SkipExperiment(Exception):
    """Signal that the current experiment should be abandoned.

    Raised when ``skip_current.flag`` is found; it tells the scheduler to
    move on to the next experiment.
    """
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn.functional as F
|
||
import torchaudio
|
||
from PIL import Image, ImageDraw
|
||
|
||
import comfy.utils
|
||
import folder_paths
|
||
|
||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||
from selva_core.model.utils.features_utils import FeaturesUtils
|
||
from selva_core.model.flow_matching import FlowMatching
|
||
from selva_core.model.lora import (
|
||
apply_lora, get_lora_state_dict, get_lora_and_base_state_dict, load_lora,
|
||
spectral_surgery,
|
||
)
|
||
|
||
|
||
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
|
||
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Data helpers (mirror train_lora.py)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _load_prompts(data_dir: Path) -> dict:
    """Read the optional ``prompts.txt`` in *data_dir* into a dict.

    Each useful line has the form ``<name>: <prompt>``; only the first colon
    separates the two halves and both halves are whitespace-stripped. Blank
    lines, lines starting with ``#`` and lines without a colon are ignored.

    Returns an empty dict when the file does not exist.
    """
    prompts_file = data_dir / "prompts.txt"
    if not prompts_file.exists():
        return {}

    entries = {}
    for raw in prompts_file.read_text(encoding="utf-8").splitlines():
        text = raw.strip()
        if not text or text.startswith("#"):
            continue
        name, sep, prompt = text.partition(":")
        if sep:
            entries[name.strip()] = prompt.strip()
    return entries
|
||
|
||
|
||
def _find_audio(npz_path: Path) -> Path | None:
|
||
for ext in _AUDIO_EXTS:
|
||
c = npz_path.with_suffix(ext)
|
||
if c.exists():
|
||
return c
|
||
return None
|
||
|
||
|
||
def _load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
    """Load an audio file as a mono 1-D float tensor of fixed length.

    The file is read with ``torchaudio.load``; if that fails with an error
    whose message mentions torchcodec, the file is read with ``soundfile``
    instead. Multi-channel audio is averaged to mono, resampled to
    *target_sr* when needed (a warning is printed once per process), and
    finally truncated or zero-padded to ``int(duration * target_sr)`` samples.

    Raises:
        RuntimeError / ImportError: re-raised from ``torchaudio.load`` when
            the failure is not torchcodec-related.
    """
    try:
        waveform, sr = torchaudio.load(str(path))
    except (RuntimeError, ImportError) as e:
        # NOTE(review): a missing torchcodec backend is assumed to surface as
        # either RuntimeError (shared libs missing) or ImportError (package
        # missing) depending on the torchaudio version — confirm.
        # "libtorchcodec" contains "torchcodec", so one substring test suffices.
        if "torchcodec" not in str(e).lower():
            raise
        # torchcodec unavailable — fall back to soundfile
        import soundfile as sf
        data, sr = sf.read(str(path), always_2d=True)  # [frames, channels]
        waveform = torch.from_numpy(data.T).float()    # [channels, frames]

    if waveform.shape[0] > 1:
        waveform = waveform.mean(0, keepdim=True)
    waveform = waveform.squeeze(0).float()

    if sr != target_sr:
        # Warn only once; the flag is stored on the function object itself.
        if not getattr(_load_audio, "_sr_warned", False):
            print(f"[LoRA Trainer] WARNING: audio sr={sr} ≠ target {target_sr}, resampling on-the-fly. "
                  f"Pre-resample with SelVA Dataset Resampler for faster loading.", flush=True)
            _load_audio._sr_warned = True
        waveform = torchaudio.functional.resample(
            waveform.unsqueeze(0), sr, target_sr).squeeze(0)

    target_len = int(duration * target_sr)
    if waveform.shape[0] >= target_len:
        return waveform[:target_len]
    return F.pad(waveform, (0, target_len - waveform.shape[0]))
|
||
|
||
|
||
def _load_npz(path: Path) -> dict:
    """Load a feature bundle from an ``.npz`` file.

    Returns a dict with ``"clip_features"`` and ``"sync_features"`` as torch
    tensors and, when the archive contains a ``prompt`` entry, ``"prompt"``
    as a ``str``. The archive is opened with ``allow_pickle=False``.

    Raises:
        KeyError: if ``clip_features`` or ``sync_features`` is missing.
    """
    # Use the archive as a context manager so its file handle is closed as
    # soon as the arrays have been read, instead of leaking until GC.
    with np.load(str(path), allow_pickle=False) as data:
        bundle = {
            "clip_features": torch.from_numpy(data["clip_features"]),
            "sync_features": torch.from_numpy(data["sync_features"]),
        }
        if "prompt" in data:
            bundle["prompt"] = str(data["prompt"])
    return bundle
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Eval sample
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
                 num_steps: int = 25, seed: int = 42, clip_idx: int = 0):
    """Generate one audio sample from a training clip without CFG.

    The conditioning comes from ``dataset[clip_idx]`` and the starting noise
    from a generator seeded with *seed*, so samples taken at different
    checkpoints can be compared directly. The generator is put in eval mode
    for the duration of the call and always switched back to train mode.

    Returns:
        ``(audio, sample_rate)`` where ``audio`` is a float32 CPU tensor
        (``[1, L]`` for a single-item batch), or ``(None, None)`` if anything
        in the pass raised.
    """
    generator.eval()
    try:
        _, clip_cpu, sync_cpu, text_cpu = dataset[clip_idx]
        cond_clip = clip_cpu.to(device, dtype)
        cond_sync = sync_cpu.to(device, dtype)
        cond_text = text_cpu.to(device, dtype)

        noise_rng = torch.Generator(device=device).manual_seed(seed)
        noise = torch.randn(
            1, seq_cfg.latent_seq_len, generator.latent_dim,
            device=device, dtype=dtype, generator=noise_rng,
        )

        sampler = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)

        def _velocity(t, x):
            timestep = t.reshape(1).to(device, dtype)
            return generator.forward(x, cond_clip, cond_sync, cond_text, timestep)

        with torch.no_grad():
            latents = sampler.to_data(_velocity, noise)
            # clone() first so unnormalize's in-place ops work on a fresh
            # tensor rather than on the sampler output itself.
            latents = generator.unnormalize(latents.clone())

            # Move only the VAE+vocoder module (tod) to the target device,
            # not the whole FeaturesUtils, and restore it afterwards.
            decoder = feature_utils_orig.tod
            home_device = next(decoder.parameters()).device
            decoder.to(device)
            try:
                spec = feature_utils_orig.decode(latents)
                audio = feature_utils_orig.vocode(spec)
            finally:
                decoder.to(home_device)

            audio = audio.float().cpu()
            if audio.dim() == 2:
                audio = audio.unsqueeze(1)
            elif audio.dim() == 3 and audio.shape[1] != 1:
                # Collapse multiple channels to mono.
                audio = audio.mean(dim=1, keepdim=True)

            # Loudness-normalise to -27 dBFS RMS, then peak-limit to 1.0.
            target_rms = 10 ** (-27.0 / 20.0)
            rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
            audio = audio * (target_rms / rms)
            peak = audio.abs().max().clamp(min=1e-8)
            if peak > 1.0:
                audio = audio / peak
            return audio.squeeze(0), seq_cfg.sampling_rate

    except Exception as e:
        print(f"[LoRA Trainer] Eval sample failed: {e}", flush=True)
        return None, None
    finally:
        generator.train()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Eval spectrogram rendering
|
||
# ---------------------------------------------------------------------------
|
||
|
||
_SPEC_N_FFT = 2048
|
||
_SPEC_HOP = 512
|
||
_SPEC_DB_FLOOR = -80.0
|
||
_SPEC_LOG_BINS = 256
|
||
|
||
|
||
def _spectral_metrics(wav: torch.Tensor, sr: int) -> dict:
    """Compute spectral quality metrics for a mono [1, L] float32 CPU tensor.

    A single Hann-windowed STFT (``_SPEC_N_FFT`` / ``_SPEC_HOP``) feeds every
    metric: the time-averaged power spectrum for the frequency-domain ones and
    the per-frame mean power for the temporal one.

    Returns a dict of rounded floats:
        hf_energy_ratio      — power at or above 4 kHz / total power
        spectral_centroid_hz — power-weighted mean frequency
        spectral_rolloff_hz  — frequency below which 85% of the power sits
        spectral_flatness    — geometric mean / arithmetic mean of the power
                               spectrum (near 0 = tonal, near 1 = noise-like)
        temporal_variance    — std / mean of the per-frame RMS
                               (low value = little level variation over time)
    """
    import numpy as np

    signal = wav.squeeze(0)
    hop = min(_SPEC_HOP, _SPEC_N_FFT)
    window = torch.hann_window(_SPEC_N_FFT)
    stft = torch.stft(signal, n_fft=_SPEC_N_FFT, hop_length=hop,
                      window=window, return_complex=True)
    # Computed once and reused for both the per-bin and per-frame averages.
    power_tf = stft.abs().pow(2)  # [n_freqs, n_frames]

    power = power_tf.mean(dim=1).numpy()  # [n_freqs], averaged over time
    freqs = np.linspace(0, sr / 2, len(power))

    total = power.sum() + 1e-12  # epsilon guards against all-zero input
    hf_ratio = float(power[freqs >= 4000].sum() / total)
    centroid = float((freqs * power).sum() / total)

    cumsum = np.cumsum(power)
    rolloff_idx = np.searchsorted(cumsum, 0.85 * cumsum[-1])
    rolloff = float(freqs[min(rolloff_idx, len(freqs) - 1)])

    # Spectral flatness (Wiener entropy) of the averaged power spectrum.
    log_power = np.log(power + 1e-12)
    flatness = float(np.exp(log_power.mean()) / (power.mean() + 1e-12))

    # Coefficient of variation of the per-frame RMS.
    frame_rms = power_tf.mean(dim=0).sqrt().numpy()  # [n_frames]
    temporal_variance = float(frame_rms.std() / (frame_rms.mean() + 1e-12))

    return {
        "hf_energy_ratio": round(hf_ratio, 4),
        "spectral_centroid_hz": round(centroid, 1),
        "spectral_rolloff_hz": round(rolloff, 1),
        "spectral_flatness": round(flatness, 4),
        "temporal_variance": round(temporal_variance, 4),
    }
|
||
|
||
|
||
def _reference_metrics(gen_wav: torch.Tensor, ref_wav: torch.Tensor, sr: int) -> dict:
    """Score a generated eval sample against its reference recording.

    gen_wav, ref_wav: [1, L] float32 CPU tensors (mono). Both are cropped to
    the shorter length before analysis.

    Returns a dict of rounded floats:
        log_spectral_distance_db — RMS difference of the dB magnitude
                                   spectrograms (lower = better)
        mel_cepstral_distortion  — scaled L2 distance in log-mel space,
                                   averaged over frames (lower = better)
        per_band_correlation     — mean Pearson r over mel bands, skipping
                                   bands that are constant in either signal
                                   (higher = better; 0.0 if none qualify)
    """
    from .selva_audio_preprocessors import _mel_filterbank

    common_len = min(gen_wav.shape[-1], ref_wav.shape[-1])
    gen_sig = gen_wav[..., :common_len].squeeze(0)
    ref_sig = ref_wav[..., :common_len].squeeze(0)

    n_fft = _SPEC_N_FFT
    hop = _SPEC_HOP
    window = torch.hann_window(n_fft)

    def _magnitude(sig):
        return torch.stft(sig, n_fft, hop, window=window, return_complex=True).abs()

    gen_mag = _magnitude(gen_sig)
    ref_mag = _magnitude(ref_sig)

    n_frames = min(gen_mag.shape[1], ref_mag.shape[1])
    gen_mag = gen_mag[:, :n_frames]
    ref_mag = ref_mag[:, :n_frames]

    # Log-spectral distance (dB)
    gen_db = 20.0 * torch.log10(gen_mag.clamp(min=1e-8))
    ref_db = 20.0 * torch.log10(ref_mag.clamp(min=1e-8))
    lsd = float(((gen_db - ref_db) ** 2).mean().sqrt())

    # Project both magnitude spectrograms onto the mel filterbank.
    n_mels = 80
    fb = _mel_filterbank(sr, n_fft, n_mels, 0, sr // 2)  # [n_mels, n_freqs]
    gen_mel_log = torch.log(torch.matmul(fb, gen_mag).clamp(min=1e-8))
    ref_mel_log = torch.log(torch.matmul(fb, ref_mag).clamp(min=1e-8))

    # Per-frame RMS over mel bands, averaged over frames, with dB scaling.
    mcd = float(
        (10.0 / np.log(10))
        * ((gen_mel_log - ref_mel_log) ** 2).mean(0).sqrt().mean()
    )

    # Pearson correlation of each mel band's trajectory over time.
    gen_bands = gen_mel_log.numpy()
    ref_bands = ref_mel_log.numpy()
    band_corrs = []
    for band in range(n_mels):
        g_row, r_row = gen_bands[band], ref_bands[band]
        if g_row.std() < 1e-8 or r_row.std() < 1e-8:
            continue  # correlation is undefined for a constant band
        r_value = np.corrcoef(g_row, r_row)[0, 1]
        if np.isfinite(r_value):
            band_corrs.append(r_value)
    mean_corr = float(np.mean(band_corrs)) if band_corrs else 0.0

    return {
        "log_spectral_distance_db": round(lsd, 2),
        "mel_cepstral_distortion": round(mcd, 2),
        "per_band_correlation": round(mean_corr, 4),
    }
|
||
|
||
|
||
def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None:
    """Render an eval sample as a log-frequency dB spectrogram PNG.

    wav: [1, L] float32 CPU tensor (mono).

    The image is written to ``path`` with its suffix replaced by ``.png``;
    the plot title is ``path.stem``.
    """
    import numpy as np
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    samples = wav.squeeze(0).numpy()
    complex_spec = torch.stft(
        torch.from_numpy(samples),
        n_fft=_SPEC_N_FFT,
        hop_length=min(_SPEC_HOP, _SPEC_N_FFT),
        window=torch.hann_window(_SPEC_N_FFT),
        return_complex=True,
    )
    level_db = 20.0 * np.log10(np.maximum(complex_spec.abs().numpy(), 1e-8))
    # Clip everything more than |_SPEC_DB_FLOOR| dB below the peak.
    level_db = np.maximum(level_db, level_db.max() + _SPEC_DB_FLOOR).astype(np.float32)

    # Resample the linear frequency axis onto log-spaced bins by linear
    # interpolation between neighbouring STFT rows.
    n_freqs = level_db.shape[0]
    log_bins = np.logspace(0, np.log10(max(n_freqs - 1, 2)), _SPEC_LOG_BINS)
    lower = np.floor(log_bins).astype(int).clip(0, n_freqs - 2)
    weight = (log_bins - lower)[:, None]
    image = ((1 - weight) * level_db[lower] + weight * level_db[lower + 1]).astype(np.float32)
    image = image[::-1]  # low freq at bottom

    # Y-axis ticks: keep only the labelled frequencies that fall on the axis.
    tick_rows, tick_labels = [], []
    for hz in (100, 500, 1000, 2000, 4000, 8000, 16000):
        src_bin = hz * _SPEC_N_FFT / sr
        if 1 <= src_bin < n_freqs:
            row = int(np.searchsorted(log_bins, src_bin))
            tick_rows.append(_SPEC_LOG_BINS - 1 - min(row, _SPEC_LOG_BINS - 1))
            tick_labels.append(f"{hz // 1000}k" if hz >= 1000 else str(hz))

    # Percentile-based colour limits so outliers don't wash out the image.
    colour_lo = float(np.percentile(image, 2.0))
    colour_hi = float(np.percentile(image, 99.5))

    fig = Figure(figsize=(12, 3), dpi=120, tight_layout=True)
    ax = fig.add_subplot(1, 1, 1)
    mesh = ax.imshow(image, aspect="auto", cmap="inferno", origin="upper",
                     vmin=colour_lo, vmax=colour_hi, interpolation="antialiased")
    ax.set_yticks(tick_rows)
    ax.set_yticklabels(tick_labels, fontsize=8)
    ax.set_ylabel("Hz", fontsize=9)
    ax.set_xlabel("Time frames", fontsize=9)
    ax.set_title(path.stem, fontsize=9)
    fig.colorbar(mesh, ax=ax, label="dB", fraction=0.02, pad=0.01)

    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    canvas.print_figure(str(path.with_suffix(".png")), dpi=120)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Loss curve rendering
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _smooth_losses(losses: list[float], beta: float = 0.9) -> list[float]:
    """Return the exponential moving average of *losses*.

    The first output equals the first input; each later output is
    ``beta * previous_output + (1 - beta) * value``.
    """
    averaged = []
    for value in losses:
        if not averaged:
            averaged.append(value)
        else:
            averaged.append(beta * averaged[-1] + (1 - beta) * value)
    return averaged
|
||
|
||
|
||
def _draw_loss_curve(losses: list[float], log_interval: int,
                     start_step: int = 0, smoothed: list[float] | None = None) -> Image.Image:
    """Render a loss curve as an 800x380 RGB PIL Image.

    Args:
        losses: one value per logging interval; plotted as the raw line.
        log_interval: training steps between consecutive entries of *losses*.
        start_step: step the run started from, so the x-axis labels of a
            resumed run show absolute step numbers.
        smoothed: optional overlay series, drawn on the same axes.

    Non-finite values (NaN / inf, e.g. from a diverged run) are left out of
    the axis range and skipped when plotting, so they cannot make the
    pixel-coordinate conversion raise. With fewer than two losses, or no
    finite loss at all, only the axes and title are drawn.
    """
    W, H = 800, 380
    pl, pr, pt, pb = 70, 20, 25, 45  # left / right / top / bottom padding

    img = Image.new("RGB", (W, H), (255, 255, 255))
    draw = ImageDraw.Draw(img)

    pw = W - pl - pr
    ph = H - pt - pb

    finite_losses = [v for v in losses if math.isfinite(v)]
    if len(losses) >= 2 and finite_losses:
        lo, hi = min(finite_losses), max(finite_losses)
        if hi == lo:
            hi = lo + 1e-6  # avoid a zero-height value range
        rng = hi - lo
        n = len(losses)

        def _to_pixels(values):
            # x comes from the index (so gaps stay in place); non-finite
            # values are dropped because int() cannot convert them.
            return [
                (pl + int(i * pw / max(n - 1, 1)),
                 pt + int((1.0 - (v - lo) / rng) * ph))
                for i, v in enumerate(values)
                if math.isfinite(v)
            ]

        # Horizontal grid + y-axis labels
        for i in range(5):
            y = pt + int(i * ph / 4)
            val = hi - i * rng / 4
            draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
            draw.text((2, y - 7), f"{val:.4f}", fill=(120, 120, 120))

        # Raw loss line
        pts = _to_pixels(losses)
        if len(pts) >= 2:
            draw.line(pts, fill=(200, 220, 255), width=1)

        # Smoothed overlay
        if smoothed is not None and len(smoothed) >= 2:
            spts = _to_pixels(smoothed)
            if len(spts) >= 2:
                draw.line(spts, fill=(66, 133, 244), width=2)

        # x-axis step labels — account for start_step so resumed runs are correct
        first_step = start_step + log_interval
        last_step = start_step + n * log_interval
        for i in range(5):
            x = pl + int(i * pw / 4)
            step = int(first_step + i * (last_step - first_step) / 4)
            draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))

    # Axes
    draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
    draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
    draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))

    return img
|
||
|
||
|
||
def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
    """Turn a PIL Image into a float32 tensor with a leading batch axis.

    Pixel values are scaled by 1/255; for an RGB image the result is
    ``[1, H, W, 3]``, the IMAGE layout ComfyUI expects.
    """
    pixels = np.array(img).astype(np.float32)
    scaled = pixels / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
|
||
|
||
|
||
def _prepare_dataset(model: dict, data_dir: Path, device) -> list:
    """Build the in-memory training set from ``.npz`` + audio pairs.

    For every ``*.npz`` in *data_dir* (sorted) that has a matching audio file,
    the audio is encoded to a latent with a freshly loaded VAE, the prompt is
    encoded with the CLIP text encoder of ``model["feature_utils"]``, and the
    stored clip/sync features are padded or trimmed to the lengths in
    ``model["seq_cfg"]``. Clips without audio, or that fail to process, are
    reported and skipped.

    Returns a list of ``(latents, clip_features, sync_features, text_clip)``
    tuples. The VAE is released before returning, so call this once and reuse
    the result across training jobs (e.g. in the scheduler).

    Raises:
        FileNotFoundError: the VAE checkpoint is missing.
        ValueError: no ``.npz`` files exist, or no clip could be loaded.
    """
    mode = model["mode"]
    seq_cfg = model["seq_cfg"]
    feature_utils_orig = model["feature_utils"]

    def _fit_length(seq, wanted):
        # Zero-pad at the end of, or truncate, dim 1 to exactly `wanted`.
        have = seq.shape[1]
        if have < wanted:
            return F.pad(seq, (0, 0, 0, wanted - have))
        if have > wanted:
            return seq[:, :wanted, :]
        return seq

    vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
    vae_path = _SELVA_DIR / "ext" / vae_name
    if not vae_path.exists():
        raise FileNotFoundError(
            f"[LoRA Trainer] VAE weight not found: {vae_path}. "
            "Run SelVA Model Loader first to auto-download weights."
        )

    print("[LoRA Trainer] Loading VAE encoder...", flush=True)
    # Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
    vae_utils = FeaturesUtils(
        tod_vae_ckpt=str(vae_path),
        enable_conditions=False,
        mode=mode,
        need_vae_encoder=True,
    ).to(device).eval()

    npz_files = sorted(data_dir.glob("*.npz"))
    if not npz_files:
        raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")

    prompt_map = _load_prompts(data_dir)
    default_prompt = data_dir.name  # last-resort prompt: the folder name

    print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
    pbar_load = comfy.utils.ProgressBar(len(npz_files))
    dataset = []

    for npz_path in npz_files:
        audio_path = _find_audio(npz_path)
        if audio_path is None:
            print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
            pbar_load.update(1)
            continue

        bundle = _load_npz(npz_path)
        # Priority: prompts.txt entry, then prompt stored in the npz, then folder name.
        prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
        print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)

        try:
            audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)

            # Audio → latent via the float32 VAE.
            dist = vae_utils.encode_audio(audio.unsqueeze(0).to(device))
            # VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
            latents = dist.mode().clone().transpose(1, 2).cpu()
            # STFT rounding can produce ±1 frame — force the exact seq length.
            latents = _fit_length(latents, seq_cfg.latent_seq_len)

            # Text → CLIP features (reuse already-loaded CLIP from inference model)
            text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()

            # Shorter videos yield fewer frames; fixed lengths keep later
            # batching (stack) from failing.
            clip_f = _fit_length(bundle["clip_features"], seq_cfg.clip_seq_len)
            sync_f = _fit_length(bundle["sync_features"], seq_cfg.sync_seq_len)

            dataset.append((latents, clip_f, sync_f, text_clip))
        except Exception as e:
            print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
            traceback.print_exc()

        pbar_load.update(1)

    # VAE no longer needed — free memory
    del vae_utils
    soft_empty_cache()

    if not dataset:
        raise ValueError("[LoRA Trainer] No clips could be loaded.")
    print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)

    return dataset
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Node
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class SelvaLoraTrainer:
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
return {
|
||
"required": {
|
||
"model": ("SELVA_MODEL",),
|
||
"data_dir": ("STRING", {
|
||
"default": "",
|
||
"tooltip": "Directory containing .npz feature files and paired audio files.",
|
||
}),
|
||
"output_dir": ("STRING", {
|
||
"default": "lora_output",
|
||
"tooltip": "Where to save adapter checkpoints.",
|
||
}),
|
||
"steps": ("INT", {
|
||
"default": 2000, "min": 100, "max": 100000,
|
||
"tooltip": "Total training steps.",
|
||
}),
|
||
"rank": ("INT", {
|
||
"default": 16, "min": 1, "max": 128,
|
||
"tooltip": "LoRA rank. Higher = more capacity, more VRAM. 16 is a safe default.",
|
||
}),
|
||
"lr": ("FLOAT", {
|
||
"default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-6,
|
||
"tooltip": "Learning rate.",
|
||
}),
|
||
},
|
||
"optional": {
|
||
"alpha": ("FLOAT", {
|
||
"default": 0.0, "min": 0.0, "max": 512.0, "step": 0.5,
|
||
"tooltip": "LoRA alpha. 0 = use 2×rank (fewer intruder dimensions, "
|
||
"arXiv:2410.21228). Set explicitly to override.",
|
||
}),
|
||
"target": ("STRING", {
|
||
"default": "attn.qkv",
|
||
"tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
|
||
}),
|
||
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32,
|
||
"tooltip": "Number of clips per training step. Higher = more stable gradients, more VRAM."}),
|
||
"warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}),
|
||
"grad_accum": ("INT", {"default": 1, "min": 1, "max": 32,
|
||
"tooltip": "Gradient accumulation steps. Usually 1 when batch_size > 1."}),
|
||
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
|
||
"resume_path": ("STRING", {
|
||
"default": "",
|
||
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
|
||
}),
|
||
"seed": ("INT", {"default": 42}),
|
||
"timestep_mode": (["uniform", "logit_normal", "curriculum"], {
|
||
"default": "uniform",
|
||
"tooltip": "How to sample training timesteps. "
|
||
"uniform: all timesteps equally (matches original MMAudio). "
|
||
"logit_normal: concentrates near t=0.5. "
|
||
"curriculum: logit_normal for first curriculum_switch% of steps then uniform (recommended for small datasets).",
|
||
}),
|
||
"logit_normal_sigma": ("FLOAT", {
|
||
"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1,
|
||
"tooltip": "Spread of the logit-normal distribution. "
|
||
"1.0 = moderate peak at t=0.5. Higher approaches uniform. "
|
||
"Used with logit_normal and curriculum modes.",
|
||
}),
|
||
"curriculum_switch": ("FLOAT", {
|
||
"default": 0.6, "min": 0.1, "max": 0.9, "step": 0.05,
|
||
"tooltip": "Fraction of steps to run logit_normal before switching to uniform. "
|
||
"0.6 = switch at 60% of total steps. Only used with timestep_mode=curriculum.",
|
||
}),
|
||
"lora_dropout": ("FLOAT", {
|
||
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.01,
|
||
"tooltip": "Dropout applied to the LoRA path only (not the frozen base weights). "
|
||
"0=disabled. 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610). "
|
||
"Forced to 0 when using PiSSA init.",
|
||
}),
|
||
"lora_plus_ratio": ("FLOAT", {
|
||
"default": 1.0, "min": 1.0, "max": 32.0, "step": 1.0,
|
||
"tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. "
|
||
"1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).",
|
||
}),
|
||
"init_mode": (["standard", "pissa"], {
|
||
"default": "pissa",
|
||
"tooltip": "LoRA initialization mode. "
|
||
"standard: Kaiming-uniform A + zero B (classic LoRA). "
|
||
"pissa: A and B from top-r SVD of pretrained weight — starts "
|
||
"on-manifold, eliminates intruder dimensions (arXiv:2404.02948). "
|
||
"Recommended for audio generation where off-manifold latents cause noise.",
|
||
}),
|
||
"use_rslora": ("BOOLEAN", {
|
||
"default": True,
|
||
"tooltip": "Rank-stabilized LoRA scaling: alpha/sqrt(rank) instead of alpha/rank. "
|
||
"Prevents gradient collapse at high ranks (128+). Recommended (arXiv:2312.03732).",
|
||
}),
|
||
"lr_schedule": (["constant", "cosine"], {
|
||
"default": "constant",
|
||
"tooltip": "LR schedule after warmup. "
|
||
"constant: flat LR for all steps. "
|
||
"cosine: decay from lr to ~0 following a cosine curve — "
|
||
"prevents oscillation when LR is slightly too high.",
|
||
}),
|
||
"latent_mixup_alpha": ("FLOAT", {
|
||
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.05,
|
||
"tooltip": "Beta distribution alpha for latent mixup (MusicLDM, arXiv:2308.01546). "
|
||
"0 = disabled. 0.4 recommended. Mixes pairs of training latents "
|
||
"to reduce memorization on small datasets.",
|
||
}),
|
||
"latent_noise_sigma": ("FLOAT", {
|
||
"default": 0.0, "min": 0.0, "max": 0.1, "step": 0.005,
|
||
"tooltip": "Additive Gaussian noise on training latents, scaled by x1.std(). "
|
||
"0 = disabled. 0.01–0.03 adds mild regularization against overfitting.",
|
||
}),
|
||
},
|
||
}
|
||
|
||
RETURN_TYPES = ("SELVA_MODEL", "STRING", "IMAGE")
|
||
RETURN_NAMES = ("model", "adapter_path", "loss_curve")
|
||
OUTPUT_TOOLTIPS = (
|
||
"Model with trained LoRA adapter applied — connect directly to Sampler.",
|
||
"Path to adapter_final.pt — use with SelVA LoRA Loader in future sessions.",
|
||
"Training loss curve.",
|
||
)
|
||
FUNCTION = "train"
|
||
CATEGORY = SELVA_CATEGORY
|
||
DESCRIPTION = (
|
||
"Trains a LoRA adapter on a dataset of .npz feature files + paired audio files. "
|
||
"Blocks the queue for the duration of training. "
|
||
"Prepare the dataset with SelVA Feature Extractor (set a name to get numbered .npz files) "
|
||
"and pair each .npz with a clean audio file of the same stem."
|
||
)
|
||
|
||
def train(self, model, data_dir, output_dir, steps, rank, lr,
          alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
          grad_accum=1, save_every=500, resume_path="", seed=42,
          timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
          lora_dropout=0.0, lora_plus_ratio=1.0,
          init_mode="pissa", use_rslora=True, lr_schedule="constant",
          latent_mixup_alpha=0.0, latent_noise_sigma=0.0):
    """Node entry point: prepare the dataset, then run LoRA training.

    Seeds ``torch`` and ``random``, resolves *output_dir* (relative paths are
    placed under the ComfyUI output directory) and creates it, derives the
    effective alpha (``2 * rank`` when *alpha* is 0) and the target suffix
    tuple, loads the dataset, and delegates to ``_train_inner`` with inference
    mode disabled and gradients enabled.

    Returns:
        ``(patched_model, adapter_path, loss_curve)`` as produced by
        ``_train_inner``.
    """
    torch.manual_seed(seed)
    random.seed(seed)

    device = get_device()
    dtype = model["dtype"]
    variant = model["variant"]
    mode = model["mode"]
    seq_cfg = model["seq_cfg"]
    feature_utils_orig = model["feature_utils"]

    data_dir = Path(data_dir.strip())

    # A path that is not absolute — or, on Windows, one that is reported as
    # absolute but carries no drive letter (e.g. "/lora_output") — is treated
    # as a subfolder of the ComfyUI output directory.
    import sys as _sys
    resolved = Path(output_dir.strip())
    driveless_on_windows = (
        _sys.platform == "win32"
        and resolved.is_absolute()
        and not resolved.drive
    )
    if driveless_on_windows or not resolved.is_absolute():
        # Strip the anchor (leading slash, if any) before joining.
        resolved = Path(folder_paths.get_output_directory()) / resolved.relative_to(resolved.anchor)
    print(f"[LoRA Trainer] output_dir resolved to: {resolved}", flush=True)
    output_dir = resolved
    output_dir.mkdir(parents=True, exist_ok=True)

    if alpha > 0.0:
        alpha_val = float(alpha)
    else:
        alpha_val = float(2 * rank)  # 0 means "use 2×rank"
    target_suffixes = tuple(target.strip().split())

    dataset = _prepare_dataset(model, data_dir, device)

    # ComfyUI executes nodes inside torch.inference_mode(). Inference tensors
    # can't participate in autograd even with enable_grad, so inference mode
    # is switched off for the whole training call.
    with torch.inference_mode(False), torch.enable_grad():
        result = self._train_inner(
            model, dataset, feature_utils_orig, seq_cfg,
            device, dtype, variant, mode,
            data_dir, output_dir, steps, rank, lr,
            alpha_val, target_suffixes, batch_size, warmup_steps,
            grad_accum, save_every, resume_path, seed,
            timestep_mode=timestep_mode,
            logit_normal_sigma=logit_normal_sigma,
            curriculum_switch=curriculum_switch,
            lora_dropout=lora_dropout,
            lora_plus_ratio=lora_plus_ratio,
            lr_schedule=lr_schedule,
            init_mode=init_mode,
            use_rslora=use_rslora,
            latent_mixup_alpha=latent_mixup_alpha,
            latent_noise_sigma=latent_noise_sigma,
        )
    return (result["patched_model"], result["adapter_path"], result["loss_curve"])
|
||
|
||
def _train_inner(
|
||
self, model, dataset, feature_utils_orig, seq_cfg,
|
||
device, dtype, variant, mode,
|
||
data_dir, output_dir, steps, rank, lr,
|
||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||
grad_accum, save_every, resume_path, seed,
|
||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant",
|
||
init_mode="pissa", use_rslora=True,
|
||
latent_mixup_alpha=0.0, latent_noise_sigma=0.0,
|
||
):
|
||
# --- Prepare generator copy with LoRA ---
|
||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||
|
||
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
|
||
target_suffixes=target_suffixes, dropout=lora_dropout,
|
||
init_mode=init_mode, use_rslora=use_rslora)
|
||
if n_lora == 0:
|
||
raise RuntimeError(
|
||
f"[LoRA Trainer] No layers matched target={target_suffixes}. "
|
||
"Check the 'target' field."
|
||
)
|
||
scale_str = f"alpha/√rank={alpha_val/math.sqrt(rank):.2f}" if use_rslora \
|
||
else f"alpha/rank={alpha_val/rank:.2f}"
|
||
print(f"[LoRA Trainer] Wrapped {n_lora} layers "
|
||
f"(rank={rank}, alpha={alpha_val}, {scale_str}, "
|
||
f"init={init_mode}, dropout={lora_dropout})", flush=True)
|
||
|
||
for name, p in generator.named_parameters():
|
||
p.requires_grad_("lora_" in name)
|
||
|
||
generator.update_seq_lengths(
|
||
latent_seq_len=seq_cfg.latent_seq_len,
|
||
clip_seq_len=seq_cfg.clip_seq_len,
|
||
sync_seq_len=seq_cfg.sync_seq_len,
|
||
)
|
||
|
||
# --- Optimizer + scheduler ---
|
||
# LoRA+: split A and B into separate param groups with different LRs.
|
||
# ratio=1.0 = standard LoRA (same LR for both). ratio=16 = LoRA+.
|
||
lora_A_params = [p for n, p in generator.named_parameters() if "lora_A" in n and p.requires_grad]
|
||
lora_B_params = [p for n, p in generator.named_parameters() if "lora_B" in n and p.requires_grad]
|
||
optimizer = torch.optim.AdamW([
|
||
{"params": lora_A_params, "lr": lr},
|
||
{"params": lora_B_params, "lr": lr * lora_plus_ratio},
|
||
], weight_decay=0.0)
|
||
if lora_plus_ratio != 1.0:
|
||
print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True)
|
||
|
||
if lr_schedule == "cosine":
|
||
def lr_lambda(s):
|
||
if s < warmup_steps:
|
||
return s / max(1, warmup_steps)
|
||
progress = (s - warmup_steps) / max(1, steps - warmup_steps)
|
||
return max(1e-6 / lr, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
||
print(f"[LoRA Trainer] LR schedule: cosine decay {lr:.2e} → 0", flush=True)
|
||
else:
|
||
def lr_lambda(s):
    """Return the LR multiplier for scheduler step ``s``.

    Linear warmup while ``s < warmup_steps``; a constant 1.0 afterwards.
    """
    # max() guards against warmup_steps == 0.
    if s < warmup_steps:
        return s / max(1, warmup_steps)
    return 1.0
|
||
|
||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
||
|
||
# --- Resume ---
|
||
start_step = 0
|
||
if resume_path.strip():
|
||
ckpt = torch.load(resume_path.strip(), map_location="cpu", weights_only=False)
|
||
if "step" not in ckpt:
|
||
raise ValueError("[LoRA Trainer] Checkpoint has no step info.")
|
||
start_step = ckpt["step"]
|
||
if start_step >= steps:
|
||
raise ValueError(
|
||
f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
|
||
)
|
||
load_lora(generator, ckpt["state_dict"])
|
||
optimizer.load_state_dict(ckpt["optimizer"])
|
||
scheduler.load_state_dict(ckpt["scheduler"])
|
||
print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)
|
||
|
||
# --- Training loop ---
|
||
generator.train()
|
||
optimizer.zero_grad()
|
||
|
||
log_interval = 50
|
||
remaining = steps - start_step
|
||
if remaining < log_interval:
|
||
raise ValueError(
|
||
f"[LoRA Trainer] Only {remaining} steps remaining (steps={steps}, "
|
||
f"start_step={start_step}). Need at least {log_interval} steps to "
|
||
"record any loss — increase 'steps' or lower the resume checkpoint."
|
||
)
|
||
pbar_train = comfy.utils.ProgressBar(remaining)
|
||
loss_history = []
|
||
running_loss = 0.0
|
||
grad_norm_history = []
|
||
spectral_metrics = {} # {step: {hf_energy_ratio, spectral_centroid_hz, spectral_rolloff_hz}}
|
||
running_grad_norm = 0.0
|
||
grad_norm_count = 0
|
||
|
||
meta = {
|
||
"variant": variant,
|
||
"rank": rank,
|
||
"alpha": alpha_val,
|
||
"target": list(target_suffixes),
|
||
"steps": steps,
|
||
"timestep_mode": timestep_mode,
|
||
"logit_normal_sigma": logit_normal_sigma,
|
||
"curriculum_switch": curriculum_switch,
|
||
"lora_dropout": lora_dropout,
|
||
"lora_plus_ratio": lora_plus_ratio,
|
||
"lr_schedule": lr_schedule,
|
||
"init_mode": init_mode,
|
||
"use_rslora": use_rslora,
|
||
"latent_mixup_alpha": latent_mixup_alpha,
|
||
"latent_noise_sigma": latent_noise_sigma,
|
||
}
|
||
|
||
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
|
||
curriculum_switch_step = start_step + int((steps - start_step) * curriculum_switch)
|
||
_curriculum_switched = False
|
||
|
||
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
||
f"(step {start_step + 1} → {steps}, batch_size={batch_size}, "
|
||
f"timestep_mode={timestep_mode})\n", flush=True)
|
||
|
||
# Load reference audio for eval comparison against original
|
||
ref_wav = None
|
||
try:
|
||
npz_files = sorted(data_dir.glob("*.npz"))
|
||
if npz_files:
|
||
ref_audio_path = _find_audio(npz_files[0])
|
||
if ref_audio_path is not None:
|
||
ref_wav = _load_audio(ref_audio_path, seq_cfg.sampling_rate,
|
||
seq_cfg.duration).unsqueeze(0) # [1, L]
|
||
print(f"[LoRA Trainer] Reference audio: {ref_audio_path.name}", flush=True)
|
||
except Exception as e:
|
||
print(f"[LoRA Trainer] Could not load reference audio: {e}", flush=True)
|
||
|
||
last_step = start_step
|
||
completed = False
|
||
try:
|
||
for step in range(start_step + 1, steps + 1):
|
||
if batch_size <= len(dataset):
|
||
batch = random.sample(dataset, k=batch_size)
|
||
else:
|
||
batch = random.choices(dataset, k=batch_size)
|
||
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||
|
||
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
||
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
|
||
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
||
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
|
||
|
||
x1 = generator.normalize(x1)
|
||
|
||
# Latent mixup (MusicLDM, arXiv:2308.01546)
|
||
if latent_mixup_alpha > 0 and x1.shape[0] > 1:
|
||
lam = torch.distributions.Beta(
|
||
latent_mixup_alpha, latent_mixup_alpha
|
||
).sample().to(device)
|
||
idx = torch.randperm(x1.shape[0], device=device)
|
||
x1 = lam * x1 + (1 - lam) * x1[idx]
|
||
|
||
# Latent noise injection
|
||
if latent_noise_sigma > 0:
|
||
x1 = x1 + torch.randn_like(x1) * latent_noise_sigma * x1.std()
|
||
|
||
if timestep_mode == "logit_normal" or (
|
||
timestep_mode == "curriculum" and step <= curriculum_switch_step
|
||
):
|
||
u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma
|
||
t = torch.sigmoid(u)
|
||
else:
|
||
t = torch.rand(batch_size, device=device, dtype=dtype)
|
||
|
||
if timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
|
||
print(f"[LoRA Trainer] Curriculum switch: logit_normal → uniform at step {step}", flush=True)
|
||
_curriculum_switched = True
|
||
x0 = torch.randn_like(x1)
|
||
xt = fm.get_conditional_flow(x0, x1, t)
|
||
|
||
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
|
||
loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
|
||
loss.backward()
|
||
running_loss += loss.item() * grad_accum
|
||
|
||
if step % grad_accum == 0:
|
||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||
lora_A_params + lora_B_params, max_norm=1.0
|
||
).item()
|
||
running_grad_norm += grad_norm
|
||
grad_norm_count += 1
|
||
optimizer.step()
|
||
scheduler.step()
|
||
optimizer.zero_grad()
|
||
|
||
if step % log_interval == 0:
|
||
skip_flag = output_dir.parent / "skip_current.flag"
|
||
if skip_flag.exists():
|
||
skip_flag.unlink()
|
||
exc = SkipExperiment(f"skip_current.flag detected at step {step} — skipping to next experiment")
|
||
exc.partial = {
|
||
"loss_history": list(loss_history),
|
||
"grad_norm_history": list(grad_norm_history),
|
||
"spectral_metrics": dict(spectral_metrics),
|
||
"stopped_at_step": step,
|
||
}
|
||
raise exc
|
||
|
||
avg = running_loss / log_interval
|
||
loss_history.append(avg)
|
||
# grad_norm_count can be 0 when grad_accum > log_interval
|
||
# (no optimizer step fired in this interval yet)
|
||
if grad_norm_count > 0:
|
||
avg_gnorm = running_grad_norm / grad_norm_count
|
||
grad_norm_history.append(round(avg_gnorm, 6))
|
||
gnorm_str = f" grad_norm={avg_gnorm:.4f}"
|
||
else:
|
||
grad_norm_history.append(None)
|
||
gnorm_str = ""
|
||
lr_now = scheduler.get_last_lr()[0]
|
||
print(f"[LoRA Trainer] step {step:5d}/{steps} "
|
||
f"loss={avg:.4f}{gnorm_str} "
|
||
f"lr={lr_now:.2e} bs={batch_size}", flush=True)
|
||
running_loss = 0.0
|
||
running_grad_norm = 0.0
|
||
grad_norm_count = 0
|
||
|
||
# Live preview: send updated loss curve to ComfyUI frontend
|
||
preview_img = _draw_loss_curve(loss_history, log_interval, start_step,
|
||
smoothed=_smooth_losses(loss_history))
|
||
pbar_train.update_absolute(
|
||
step - start_step, remaining, ("JPEG", preview_img, 800)
|
||
)
|
||
|
||
if step % save_every == 0 or step == steps:
|
||
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
||
# PiSSA checkpoints need base weights (residual W_res)
|
||
sd = get_lora_and_base_state_dict(generator) if init_mode == "pissa" \
|
||
else get_lora_state_dict(generator)
|
||
torch.save({
|
||
"state_dict": sd,
|
||
"optimizer": optimizer.state_dict(),
|
||
"scheduler": scheduler.state_dict(),
|
||
"step": step,
|
||
"meta": meta,
|
||
}, ckpt_path)
|
||
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
|
||
|
||
# Save a quick eval sample in samples/ subfolder
|
||
samples_dir = output_dir / "samples"
|
||
samples_dir.mkdir(exist_ok=True)
|
||
wav, sr = _eval_sample(generator, feature_utils_orig,
|
||
dataset, seq_cfg, device, dtype, seed=seed)
|
||
if wav is not None:
|
||
wav_path = samples_dir / f"step_{step:05d}.wav"
|
||
try:
|
||
torchaudio.save(str(wav_path), wav, sr)
|
||
except RuntimeError:
|
||
import soundfile as sf
|
||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||
try:
|
||
metrics = _spectral_metrics(wav, sr)
|
||
spectral_metrics[step] = metrics
|
||
print(f"[LoRA Trainer] Spectral: hf_ratio={metrics['hf_energy_ratio']:.3f} "
|
||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
||
f"rolloff={metrics['spectral_rolloff_hz']:.0f}Hz "
|
||
f"flatness={metrics['spectral_flatness']:.3f} "
|
||
f"temporal_var={metrics['temporal_variance']:.3f}", flush=True)
|
||
_save_spectrogram(wav, sr, wav_path)
|
||
if ref_wav is not None:
|
||
ref_m = _reference_metrics(wav, ref_wav, sr)
|
||
spectral_metrics[step].update(ref_m)
|
||
print(f"[LoRA Trainer] vs Reference: "
|
||
f"LSD={ref_m['log_spectral_distance_db']:.1f}dB "
|
||
f"MCD={ref_m['mel_cepstral_distortion']:.2f} "
|
||
f"corr={ref_m['per_band_correlation']:.3f}", flush=True)
|
||
except Exception as e:
|
||
print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True)
|
||
|
||
last_step = step
|
||
pbar_train.update(1)
|
||
|
||
completed = True
|
||
|
||
# ── Post-training Spectral Surgery ────────────────────────────────
|
||
# Reweight LoRA singular values using gradient sensitivity on the
|
||
# training set. Suppresses intruder dimensions, amplifies useful ones.
|
||
# (arXiv:2603.03995). Only run on normal completion.
|
||
try:
|
||
print("[LoRA Trainer] Running Spectral Surgery...", flush=True)
|
||
fm_surgery = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
||
|
||
def _calibration_fn(model_cal, step_idx):
    """Run one flow-matching forward/backward pass on a single dataset item.

    Handed to ``spectral_surgery`` as its calibration callback. The item is
    chosen by ``step_idx`` modulo the dataset length, a random timestep and
    noise tensor are drawn, and the resulting loss is back-propagated.
    Nothing is returned; the only effect is the ``backward()`` call.
    """
    latent, clip_feat, sync_feat, text_feat = dataset[step_idx % len(dataset)]

    # A 2-D latent gets a leading axis added; any other rank is used as stored.
    if latent.dim() == 2:
        latent = latent.unsqueeze(0)
    # Clone before normalising so the tensor handed to normalize() is never
    # the one held by the dataset.
    x1_b = model_cal.normalize(latent.to(device, dtype).clone())

    cond = [feat.to(device, dtype) for feat in (clip_feat, sync_feat, text_feat)]

    # Draw order (timestep first, then noise) is kept so seeded runs match.
    t_cal = torch.rand(1, device=device, dtype=dtype)
    noise = torch.randn_like(x1_b)

    x_t = fm_surgery.get_conditional_flow(noise, x1_b, t_cal)
    v_pred = model_cal.forward(x_t, *cond, t_cal)
    fm_surgery.loss(v_pred, noise, x1_b).mean().backward()
|
||
|
||
n_cal = min(128, len(dataset) * 4)
|
||
n_surgery = spectral_surgery(generator, _calibration_fn,
|
||
n_calibration=n_cal)
|
||
print(f"[LoRA Trainer] Spectral Surgery done: {n_surgery} layers processed.",
|
||
flush=True)
|
||
except Exception as e:
|
||
print(f"[LoRA Trainer] Spectral Surgery failed (non-fatal): {e}", flush=True)
|
||
|
||
finally:
|
||
# Save adapter and loss curves whether training completed or was cancelled.
|
||
# Skip if we never completed a single step (nothing useful to save).
|
||
if loss_history:
|
||
if completed:
|
||
# Normal completion — use adapter_final.pt (increment if exists)
|
||
final_path = output_dir / "adapter_final.pt"
|
||
if final_path.exists():
|
||
i = 1
|
||
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
|
||
i += 1
|
||
final_path = output_dir / f"adapter_final_{i:03d}.pt"
|
||
label = "Done"
|
||
else:
|
||
# Cancelled — include the step number so the file is useful for resume
|
||
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
|
||
label = f"Cancelled at step {last_step}"
|
||
|
||
final_sd = get_lora_and_base_state_dict(generator) if init_mode == "pissa" \
|
||
else get_lora_state_dict(generator)
|
||
torch.save({"state_dict": final_sd, "meta": meta}, final_path)
|
||
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
|
||
|
||
smoothed = _smooth_losses(loss_history)
|
||
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
||
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step,
|
||
smoothed=smoothed)
|
||
raw_img.save(str(output_dir / "loss_raw.png"))
|
||
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
|
||
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
|
||
|
||
# Reached only on normal completion (exception re-raises past this point)
|
||
generator.eval()
|
||
generator.to(next(model["generator"].parameters()).device)
|
||
patched = {**model, "generator": generator}
|
||
|
||
loss_curve = _pil_to_tensor(smoothed_img)
|
||
return {
|
||
"patched_model": patched,
|
||
"adapter_path": str(final_path),
|
||
"loss_curve": loss_curve,
|
||
"loss_history": loss_history,
|
||
"grad_norm_history": grad_norm_history,
|
||
"spectral_metrics": spectral_metrics,
|
||
"start_step": start_step,
|
||
"meta": meta,
|
||
"completed": True,
|
||
}
|