Files
ComfyUI-SelVA/nodes/selva_lora_trainer.py
T
Ethanfel 678c050f11 fix: make normalize(x1) assignment explicit in training loop
normalize() uses in-place ops so it worked, but reading the return value
makes the intent clear and guards against future refactors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 15:43:42 +02:00

901 lines
39 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import copy
import json
import math
import random
import traceback
from pathlib import Path
class SkipExperiment(Exception):
"""Raised when skip_current.flag is found — signals the scheduler to move to the next experiment."""
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from PIL import Image, ImageDraw
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from selva_core.model.utils.features_utils import FeaturesUtils
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
# ---------------------------------------------------------------------------
# Data helpers (mirror train_lora.py)
# ---------------------------------------------------------------------------
def _load_prompts(data_dir: Path) -> dict:
p = data_dir / "prompts.txt"
if not p.exists():
return {}
mapping = {}
for line in p.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if ":" in line:
fname, prompt = line.split(":", 1)
mapping[fname.strip()] = prompt.strip()
return mapping
def _find_audio(npz_path: Path) -> Path | None:
for ext in _AUDIO_EXTS:
c = npz_path.with_suffix(ext)
if c.exists():
return c
return None
def _load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
try:
waveform, sr = torchaudio.load(str(path))
except RuntimeError as e:
if "torchcodec" not in str(e).lower() and "libtorchcodec" not in str(e).lower():
raise
# torchcodec unavailable (FFmpeg shared libs missing) — fall back to soundfile
import soundfile as sf
data, sr = sf.read(str(path), always_2d=True) # [frames, channels]
waveform = torch.from_numpy(data.T).float() # [channels, frames]
if waveform.shape[0] > 1:
waveform = waveform.mean(0, keepdim=True)
waveform = waveform.squeeze(0).float()
if sr != target_sr:
waveform = torchaudio.functional.resample(
waveform.unsqueeze(0), sr, target_sr).squeeze(0)
target_len = int(duration * target_sr)
if waveform.shape[0] >= target_len:
return waveform[:target_len]
return F.pad(waveform, (0, target_len - waveform.shape[0]))
def _load_npz(path: Path) -> dict:
data = np.load(str(path), allow_pickle=False)
bundle = {
"clip_features": torch.from_numpy(data["clip_features"]),
"sync_features": torch.from_numpy(data["sync_features"]),
}
if "prompt" in data:
bundle["prompt"] = str(data["prompt"])
return bundle
# ---------------------------------------------------------------------------
# Eval sample
# ---------------------------------------------------------------------------
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
num_steps: int = 25, seed: int = 42):
"""Run a quick no-CFG inference pass on a fixed training clip.
Always uses dataset[0] and a fixed noise seed so samples across checkpoints
are directly comparable — you can hear the model improve step by step.
Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
Uses fewer ODE steps than inference (8 vs 25) for speed.
"""
generator.eval()
try:
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[0]
clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype)
rng = torch.Generator(device=device).manual_seed(seed)
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
device=device, dtype=dtype, generator=rng)
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
def velocity_fn(t, x):
return generator.forward(x, clip_f, sync_f, text_clip,
t.reshape(1).to(device, dtype))
with torch.no_grad():
x1_pred = eval_fm.to_data(velocity_fn, x0)
x1_unnorm = generator.unnormalize(x1_pred)
# feature_utils_orig may be on CPU (offload strategy) — move temporarily
orig_device = next(feature_utils_orig.parameters()).device
if orig_device != device:
feature_utils_orig.to(device)
try:
spec = feature_utils_orig.decode(x1_unnorm)
audio = feature_utils_orig.vocode(spec)
finally:
if orig_device != device:
feature_utils_orig.to(orig_device)
audio = audio.float().cpu()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True)
target_rms = 10 ** (-23.0 / 20.0) # -23 dBFS matches training data
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = audio * (target_rms / rms)
peak = audio.abs().max().clamp(min=1e-8)
if peak > 1.0:
audio = audio / peak
return audio.squeeze(0), seq_cfg.sampling_rate # [1, L]
except Exception as e:
print(f"[LoRA Trainer] Eval sample failed: {e}", flush=True)
return None, None
finally:
generator.train()
# ---------------------------------------------------------------------------
# Eval spectrogram rendering
# ---------------------------------------------------------------------------
_SPEC_N_FFT = 2048
_SPEC_HOP = 512
_SPEC_DB_FLOOR = -80.0
_SPEC_LOG_BINS = 256
def _spectral_metrics(wav: torch.Tensor, sr: int) -> dict:
"""Compute spectral quality metrics for a mono [1, L] float32 CPU tensor.
Returns:
hf_energy_ratio — energy above 4kHz / total energy (low bitrate → low value)
spectral_centroid_hz — energy-weighted mean frequency
spectral_rolloff_hz — frequency below which 85% of energy sits
"""
import numpy as np
wav_np = wav.squeeze(0).numpy()
hop = min(_SPEC_HOP, _SPEC_N_FFT)
window = torch.hann_window(_SPEC_N_FFT)
stft = torch.stft(torch.from_numpy(wav_np), n_fft=_SPEC_N_FFT, hop_length=hop,
window=window, return_complex=True)
power = stft.abs().pow(2).mean(dim=1).numpy() # [n_freqs] averaged over time
freqs = np.linspace(0, sr / 2, len(power))
total = power.sum() + 1e-12
hf_mask = freqs >= 4000
hf_ratio = float(power[hf_mask].sum() / total)
centroid = float((freqs * power).sum() / total)
cumsum = np.cumsum(power)
rolloff_idx = np.searchsorted(cumsum, 0.85 * cumsum[-1])
rolloff = float(freqs[min(rolloff_idx, len(freqs) - 1)])
# Spectral flatness (Wiener entropy): geometric_mean / arithmetic_mean of power
# 0.0 = pure tone, 1.0 = white noise — rising value = noise contamination
log_power = np.log(power + 1e-12)
flatness = float(np.exp(log_power.mean()) / (power.mean() + 1e-12))
# Temporal energy variance — how dynamic the audio is
# Compute RMS per frame, take std. Low value = compressed/lifeless
hop = min(_SPEC_HOP, _SPEC_N_FFT)
window = torch.hann_window(_SPEC_N_FFT)
stft_full = torch.stft(torch.from_numpy(wav_np), n_fft=_SPEC_N_FFT, hop_length=hop,
window=window, return_complex=True)
frame_rms = stft_full.abs().pow(2).mean(dim=0).sqrt().numpy() # [n_frames]
temporal_variance = float(frame_rms.std() / (frame_rms.mean() + 1e-12))
return {
"hf_energy_ratio": round(hf_ratio, 4),
"spectral_centroid_hz": round(centroid, 1),
"spectral_rolloff_hz": round(rolloff, 1),
"spectral_flatness": round(flatness, 4),
"temporal_variance": round(temporal_variance, 4),
}
def _save_spectrogram(wav: torch.Tensor, sr: int, path: Path) -> None:
"""Save a log-frequency dB spectrogram PNG for an eval sample.
wav: [1, L] float32 CPU tensor (mono).
"""
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
wav_np = wav.squeeze(0).numpy()
hop = min(_SPEC_HOP, _SPEC_N_FFT)
window = torch.hann_window(_SPEC_N_FFT)
stft = torch.stft(torch.from_numpy(wav_np), n_fft=_SPEC_N_FFT, hop_length=hop,
window=window, return_complex=True)
mag = stft.abs().numpy()
db = 20.0 * np.log10(np.maximum(mag, 1e-8))
db = np.maximum(db, db.max() + _SPEC_DB_FLOOR).astype(np.float32)
# Log-frequency resampling
n_freqs = db.shape[0]
src_idx = np.logspace(0, np.log10(max(n_freqs - 1, 2)), _SPEC_LOG_BINS)
lo = np.floor(src_idx).astype(int).clip(0, n_freqs - 2)
frac = (src_idx - lo)[:, None]
spec = ((1 - frac) * db[lo] + frac * db[lo + 1]).astype(np.float32)
spec = spec[::-1] # low freq at bottom
# Y-tick positions (Hz labels)
tgt_hz = [100, 500, 1000, 2000, 4000, 8000, 16000]
tpos, tlbl = [], []
for hz in tgt_hz:
bin_f = hz * _SPEC_N_FFT / sr
if bin_f < 1 or bin_f >= n_freqs:
continue
pos = int(np.searchsorted(src_idx, bin_f))
tpos.append(_SPEC_LOG_BINS - 1 - min(pos, _SPEC_LOG_BINS - 1))
tlbl.append(f"{hz // 1000}k" if hz >= 1000 else str(hz))
vmin = float(np.percentile(spec, 2.0))
vmax = float(np.percentile(spec, 99.5))
fig = Figure(figsize=(12, 3), dpi=120, tight_layout=True)
ax = fig.add_subplot(1, 1, 1)
im = ax.imshow(spec, aspect="auto", cmap="inferno", origin="upper",
vmin=vmin, vmax=vmax, interpolation="antialiased")
ax.set_yticks(tpos)
ax.set_yticklabels(tlbl, fontsize=8)
ax.set_ylabel("Hz", fontsize=9)
ax.set_xlabel("Time frames", fontsize=9)
ax.set_title(path.stem, fontsize=9)
fig.colorbar(im, ax=ax, label="dB", fraction=0.02, pad=0.01)
canvas = FigureCanvasAgg(fig)
canvas.draw()
canvas.print_figure(str(path.with_suffix(".png")), dpi=120)
# ---------------------------------------------------------------------------
# Loss curve rendering
# ---------------------------------------------------------------------------
def _smooth_losses(losses: list[float], beta: float = 0.9) -> list[float]:
"""Exponential moving average smoothing."""
smoothed, ema = [], None
for v in losses:
ema = v if ema is None else beta * ema + (1 - beta) * v
smoothed.append(ema)
return smoothed
def _draw_loss_curve(losses: list[float], log_interval: int,
start_step: int = 0, smoothed: list[float] | None = None) -> Image.Image:
"""Render a loss curve as a PIL Image."""
W, H = 800, 380
pl, pr, pt, pb = 70, 20, 25, 45
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr
ph = H - pt - pb
if len(losses) >= 2:
lo, hi = min(losses), max(losses)
if hi == lo:
hi = lo + 1e-6
rng = hi - lo
# Horizontal grid + y-axis labels
for i in range(5):
y = pt + int(i * ph / 4)
val = hi - i * rng / 4
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
draw.text((2, y - 7), f"{val:.4f}", fill=(120, 120, 120))
# Raw loss line
n = len(losses)
pts = []
for i, v in enumerate(losses):
x = pl + int(i * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
pts.append((x, y))
draw.line(pts, fill=(200, 220, 255), width=1)
# Smoothed overlay
if smoothed is not None and len(smoothed) >= 2:
spts = []
for i, v in enumerate(smoothed):
x = pl + int(i * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
spts.append((x, y))
draw.line(spts, fill=(66, 133, 244), width=2)
# x-axis step labels — account for start_step so resumed runs are correct
first_step = start_step + log_interval
last_step = start_step + n * log_interval
for i in range(5):
x = pl + int(i * pw / 4)
step = int(first_step + i * (last_step - first_step) / 4)
draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))
# Axes
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))
return img
def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
"""Convert a PIL Image to a [1, H, W, 3] float32 IMAGE tensor for ComfyUI."""
arr = np.array(img).astype(np.float32) / 255.0
return torch.from_numpy(arr).unsqueeze(0)
def _prepare_dataset(model: dict, data_dir: Path, device) -> list:
"""Load VAE, encode audio clips, load .npz features.
Returns a list of (latents, clip_features, sync_features, text_clip) CPU tensors.
The VAE is freed after encoding. Call this once and reuse the dataset across
multiple training jobs (e.g. in the scheduler).
"""
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
vae_path = _SELVA_DIR / "ext" / vae_name
if not vae_path.exists():
raise FileNotFoundError(
f"[LoRA Trainer] VAE weight not found: {vae_path}. "
"Run SelVA Model Loader first to auto-download weights."
)
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
vae_utils = FeaturesUtils(
tod_vae_ckpt=str(vae_path),
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
).to(device).eval()
npz_files = sorted(data_dir.glob("*.npz"))
if not npz_files:
raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")
prompt_map = _load_prompts(data_dir)
default_prompt = data_dir.name
print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
pbar_load = comfy.utils.ProgressBar(len(npz_files))
dataset = []
for npz_path in npz_files:
audio_path = _find_audio(npz_path)
if audio_path is None:
print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
pbar_load.update(1)
continue
bundle = _load_npz(npz_path)
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)
try:
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
# Audio → latent via VAE (float32: mel_converter/stft require float32)
# encode_audio is @inference_mode — .clone() exits inference mode
audio_b = audio.unsqueeze(0).to(device)
dist = vae_utils.encode_audio(audio_b)
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
x1 = dist.mode().clone().transpose(1, 2).cpu()
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
tgt = seq_cfg.latent_seq_len
if x1.shape[1] < tgt:
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
elif x1.shape[1] > tgt:
x1 = x1[:, :tgt, :]
# Text → CLIP features (reuse already-loaded CLIP from inference model)
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
# Pad/trim clip and sync features to fixed seq lengths — clips from
# shorter videos have fewer frames and would cause stack() to fail
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
sync_f = bundle["sync_features"] # [1, N_sync, 768]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e:
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
traceback.print_exc()
pbar_load.update(1)
# VAE no longer needed — free memory
del vae_utils
soft_empty_cache()
if not dataset:
raise ValueError("[LoRA Trainer] No clips could be loaded.")
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
return dataset
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaLoraTrainer:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"data_dir": ("STRING", {
"default": "",
"tooltip": "Directory containing .npz feature files and paired audio files.",
}),
"output_dir": ("STRING", {
"default": "lora_output",
"tooltip": "Where to save adapter checkpoints.",
}),
"steps": ("INT", {
"default": 2000, "min": 100, "max": 100000,
"tooltip": "Total training steps.",
}),
"rank": ("INT", {
"default": 16, "min": 1, "max": 128,
"tooltip": "LoRA rank. Higher = more capacity, more VRAM. 16 is a safe default.",
}),
"lr": ("FLOAT", {
"default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-6,
"tooltip": "Learning rate.",
}),
},
"optional": {
"alpha": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 256.0, "step": 0.5,
"tooltip": "LoRA alpha. 0 = use rank value (scale = 1.0).",
}),
"target": ("STRING", {
"default": "attn.qkv",
"tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
}),
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32,
"tooltip": "Number of clips per training step. Higher = more stable gradients, more VRAM."}),
"warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}),
"grad_accum": ("INT", {"default": 1, "min": 1, "max": 32,
"tooltip": "Gradient accumulation steps. Usually 1 when batch_size > 1."}),
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
"resume_path": ("STRING", {
"default": "",
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
}),
"seed": ("INT", {"default": 42}),
"timestep_mode": (["uniform", "logit_normal", "curriculum"], {
"default": "uniform",
"tooltip": "How to sample training timesteps. "
"uniform: all timesteps equally (matches original MMAudio). "
"logit_normal: concentrates near t=0.5. "
"curriculum: logit_normal for first curriculum_switch% of steps then uniform (recommended for small datasets).",
}),
"logit_normal_sigma": ("FLOAT", {
"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1,
"tooltip": "Spread of the logit-normal distribution. "
"1.0 = moderate peak at t=0.5. Higher approaches uniform. "
"Used with logit_normal and curriculum modes.",
}),
"curriculum_switch": ("FLOAT", {
"default": 0.6, "min": 0.1, "max": 0.9, "step": 0.05,
"tooltip": "Fraction of steps to run logit_normal before switching to uniform. "
"0.6 = switch at 60% of total steps. Only used with timestep_mode=curriculum.",
}),
"lora_dropout": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.01,
"tooltip": "Dropout applied to the LoRA path only (not the frozen base weights). "
"0=disabled. 0.050.1 helps regularize on small datasets (arXiv:2404.09610).",
}),
"lora_plus_ratio": ("FLOAT", {
"default": 1.0, "min": 1.0, "max": 32.0, "step": 1.0,
"tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. "
"1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).",
}),
"lr_schedule": (["constant", "cosine"], {
"default": "constant",
"tooltip": "LR schedule after warmup. "
"constant: flat LR for all steps. "
"cosine: decay from lr to ~0 following a cosine curve — "
"prevents oscillation when LR is slightly too high.",
}),
},
}
RETURN_TYPES = ("SELVA_MODEL", "STRING", "IMAGE")
RETURN_NAMES = ("model", "adapter_path", "loss_curve")
OUTPUT_TOOLTIPS = (
"Model with trained LoRA adapter applied — connect directly to Sampler.",
"Path to adapter_final.pt — use with SelVA LoRA Loader in future sessions.",
"Training loss curve.",
)
FUNCTION = "train"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Trains a LoRA adapter on a dataset of .npz feature files + paired audio files. "
"Blocks the queue for the duration of training. "
"Prepare the dataset with SelVA Feature Extractor (set a name to get numbered .npz files) "
"and pair each .npz with a clean audio file of the same stem."
)
def train(self, model, data_dir, output_dir, steps, rank, lr,
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
grad_accum=1, save_every=500, resume_path="", seed=42,
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant"):
torch.manual_seed(seed)
random.seed(seed)
device = get_device()
dtype = model["dtype"]
variant = model["variant"]
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
data_dir = Path(data_dir.strip())
_out_str = output_dir.strip()
_out_p = Path(_out_str)
# On Windows a Unix-style path like "/lora_output" is technically absolute
# (drive-relative) but the user almost certainly meant a subfolder of the
# ComfyUI output directory. Treat any non-absolute path AND any path whose
# only "absolute" anchor is a leading slash (no drive letter) as relative to
# the ComfyUI output folder.
import sys as _sys
_unix_style_on_windows = (
_sys.platform == "win32"
and _out_p.is_absolute()
and not _out_p.drive # e.g. Path("/foo").drive == "" on Windows
)
if not _out_p.is_absolute() or _unix_style_on_windows:
_out_p = Path(folder_paths.get_output_directory()) / _out_p.relative_to(_out_p.anchor)
print(f"[LoRA Trainer] output_dir resolved to: {_out_p}", flush=True)
output_dir = _out_p
output_dir.mkdir(parents=True, exist_ok=True)
alpha_val = float(alpha) if alpha > 0.0 else float(rank)
target_suffixes = tuple(target.strip().split())
dataset = _prepare_dataset(model, data_dir, device)
# ComfyUI executes nodes inside torch.inference_mode(). Inference tensors
# can't participate in autograd even with enable_grad — disable inference
# mode entirely so deepcopy, apply_lora, and the training loop all run
# with a clean autograd context.
with torch.inference_mode(False), torch.enable_grad():
r = self._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed,
timestep_mode, logit_normal_sigma, curriculum_switch,
lora_dropout, lora_plus_ratio, lr_schedule,
)
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
def _train_inner(
self, model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed,
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant",
):
# --- Prepare generator copy with LoRA ---
generator = copy.deepcopy(model["generator"]).to(device, dtype)
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
target_suffixes=target_suffixes, dropout=lora_dropout)
if n_lora == 0:
raise RuntimeError(
f"[LoRA Trainer] No layers matched target={target_suffixes}. "
"Check the 'target' field."
)
print(f"[LoRA Trainer] Wrapped {n_lora} layers "
f"(rank={rank}, alpha={alpha_val}, dropout={lora_dropout})", flush=True)
for name, p in generator.named_parameters():
p.requires_grad_("lora_" in name)
generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
# --- Optimizer + scheduler ---
# LoRA+: split A and B into separate param groups with different LRs.
# ratio=1.0 = standard LoRA (same LR for both). ratio=16 = LoRA+.
lora_A_params = [p for n, p in generator.named_parameters() if "lora_A" in n and p.requires_grad]
lora_B_params = [p for n, p in generator.named_parameters() if "lora_B" in n and p.requires_grad]
optimizer = torch.optim.AdamW([
{"params": lora_A_params, "lr": lr},
{"params": lora_B_params, "lr": lr * lora_plus_ratio},
], weight_decay=1e-2)
if lora_plus_ratio != 1.0:
print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True)
if lr_schedule == "cosine":
def lr_lambda(s):
if s < warmup_steps:
return s / max(1, warmup_steps)
progress = (s - warmup_steps) / max(1, steps - warmup_steps)
return max(1e-6 / lr, 0.5 * (1.0 + math.cos(math.pi * progress)))
print(f"[LoRA Trainer] LR schedule: cosine decay {lr:.2e} → 0", flush=True)
else:
def lr_lambda(s):
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
# --- Resume ---
start_step = 0
if resume_path.strip():
ckpt = torch.load(resume_path.strip(), map_location="cpu", weights_only=False)
if "step" not in ckpt:
raise ValueError("[LoRA Trainer] Checkpoint has no step info.")
start_step = ckpt["step"]
if start_step >= steps:
raise ValueError(
f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
)
load_lora(generator, ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)
# --- Training loop ---
generator.train()
optimizer.zero_grad()
log_interval = 50
remaining = steps - start_step
if remaining < log_interval:
raise ValueError(
f"[LoRA Trainer] Only {remaining} steps remaining (steps={steps}, "
f"start_step={start_step}). Need at least {log_interval} steps to "
"record any loss — increase 'steps' or lower the resume checkpoint."
)
pbar_train = comfy.utils.ProgressBar(remaining)
loss_history = []
running_loss = 0.0
grad_norm_history = []
spectral_metrics = {} # {step: {hf_energy_ratio, spectral_centroid_hz, spectral_rolloff_hz}}
running_grad_norm = 0.0
grad_norm_count = 0
meta = {
"variant": variant,
"rank": rank,
"alpha": alpha_val,
"target": list(target_suffixes),
"steps": steps,
"timestep_mode": timestep_mode,
"logit_normal_sigma": logit_normal_sigma,
"curriculum_switch": curriculum_switch,
"lora_dropout": lora_dropout,
"lora_plus_ratio": lora_plus_ratio,
"lr_schedule": lr_schedule,
}
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
curriculum_switch_step = start_step + int((steps - start_step) * curriculum_switch)
_curriculum_switched = False
print(f"\n[LoRA Trainer] Training {remaining} steps "
f"(step {start_step + 1}{steps}, batch_size={batch_size}, "
f"timestep_mode={timestep_mode})\n", flush=True)
last_step = start_step
completed = False
try:
for step in range(start_step + 1, steps + 1):
batch = random.choices(dataset, k=batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
x1 = generator.normalize(x1)
if timestep_mode == "logit_normal" or (
timestep_mode == "curriculum" and step <= curriculum_switch_step
):
u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma
t = torch.sigmoid(u)
else:
t = torch.rand(batch_size, device=device, dtype=dtype)
if timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
print(f"[LoRA Trainer] Curriculum switch: logit_normal → uniform at step {step}", flush=True)
_curriculum_switched = True
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
loss.backward()
running_loss += loss.item() * grad_accum
if step % grad_accum == 0:
grad_norm = torch.nn.utils.clip_grad_norm_(
lora_A_params + lora_B_params, max_norm=1.0
).item()
running_grad_norm += grad_norm
grad_norm_count += 1
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % log_interval == 0:
skip_flag = output_dir.parent / "skip_current.flag"
if skip_flag.exists():
skip_flag.unlink()
exc = SkipExperiment(f"skip_current.flag detected at step {step} — skipping to next experiment")
exc.partial = {
"loss_history": list(loss_history),
"grad_norm_history": list(grad_norm_history),
"spectral_metrics": dict(spectral_metrics),
"stopped_at_step": step,
}
raise exc
avg = running_loss / log_interval
loss_history.append(avg)
# grad_norm_count can be 0 when grad_accum > log_interval
# (no optimizer step fired in this interval yet)
if grad_norm_count > 0:
avg_gnorm = running_grad_norm / grad_norm_count
grad_norm_history.append(round(avg_gnorm, 6))
gnorm_str = f" grad_norm={avg_gnorm:.4f}"
else:
grad_norm_history.append(None)
gnorm_str = ""
lr_now = scheduler.get_last_lr()[0]
print(f"[LoRA Trainer] step {step:5d}/{steps} "
f"loss={avg:.4f}{gnorm_str} "
f"lr={lr_now:.2e} bs={batch_size}", flush=True)
running_loss = 0.0
running_grad_norm = 0.0
grad_norm_count = 0
# Live preview: send updated loss curve to ComfyUI frontend
preview_img = _draw_loss_curve(loss_history, log_interval, start_step,
smoothed=_smooth_losses(loss_history))
pbar_train.update_absolute(
step - start_step, remaining, ("JPEG", preview_img, 800)
)
if step % save_every == 0 or step == steps:
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
torch.save({
"state_dict": get_lora_state_dict(generator),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": step,
"meta": meta,
}, ckpt_path)
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
# Save a quick eval sample in samples/ subfolder
samples_dir = output_dir / "samples"
samples_dir.mkdir(exist_ok=True)
wav, sr = _eval_sample(generator, feature_utils_orig,
dataset, seq_cfg, device, dtype, seed=seed)
if wav is not None:
wav_path = samples_dir / f"step_{step:05d}.wav"
try:
torchaudio.save(str(wav_path), wav, sr)
except RuntimeError:
import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
try:
metrics = _spectral_metrics(wav, sr)
spectral_metrics[step] = metrics
print(f"[LoRA Trainer] Spectral: hf_ratio={metrics['hf_energy_ratio']:.3f} "
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
f"rolloff={metrics['spectral_rolloff_hz']:.0f}Hz "
f"flatness={metrics['spectral_flatness']:.3f} "
f"temporal_var={metrics['temporal_variance']:.3f}", flush=True)
_save_spectrogram(wav, sr, wav_path)
except Exception as e:
print(f"[LoRA Trainer] Spectral/spectrogram failed: {e}", flush=True)
last_step = step
pbar_train.update(1)
completed = True
finally:
# Save adapter and loss curves whether training completed or was cancelled.
# Skip if we never completed a single step (nothing useful to save).
if loss_history:
if completed:
# Normal completion — use adapter_final.pt (increment if exists)
final_path = output_dir / "adapter_final.pt"
if final_path.exists():
i = 1
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
i += 1
final_path = output_dir / f"adapter_final_{i:03d}.pt"
label = "Done"
else:
# Cancelled — include the step number so the file is useful for resume
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
label = f"Cancelled at step {last_step}"
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
smoothed = _smooth_losses(loss_history)
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step,
smoothed=smoothed)
raw_img.save(str(output_dir / "loss_raw.png"))
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
# Reached only on normal completion (exception re-raises past this point)
generator.eval()
generator.to(next(model["generator"].parameters()).device)
patched = {**model, "generator": generator}
loss_curve = _pil_to_tensor(smoothed_img)
return {
"patched_model": patched,
"adapter_path": str(final_path),
"loss_curve": loss_curve,
"loss_history": loss_history,
"grad_norm_history": grad_norm_history,
"spectral_metrics": spectral_metrics,
"start_step": start_step,
"meta": meta,
"completed": True,
}