diff --git a/nodes/__init__.py b/nodes/__init__.py index 0fcc648..c696387 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -14,6 +14,8 @@ _NODES = { "SelvaVaeRoundtrip": (".selva_vae_roundtrip", "SelvaVaeRoundtrip", "SelVA VAE Roundtrip"), "SelvaHfSmoother": (".selva_audio_preprocessors", "SelvaHfSmoother", "SelVA HF Smoother"), "SelvaSpectralMatcher": (".selva_audio_preprocessors", "SelvaSpectralMatcher", "SelVA Spectral Matcher"), + "SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"), + "SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"), } for key, (module_path, class_name, display_name) in _NODES.items(): diff --git a/nodes/selva_sampler.py b/nodes/selva_sampler.py index cd9362a..1bae020 100644 --- a/nodes/selva_sampler.py +++ b/nodes/selva_sampler.py @@ -39,6 +39,10 @@ class SelvaSampler: "default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0, "tooltip": "Target RMS level in dBFS when normalize=True. -27 matches the measured RMS of LUFS-normalized training clips. Increase toward -20 for louder output.", }), + "textual_inversion": ("TEXTUAL_INVERSION", { + "tooltip": "Learned token embeddings from SelVA Textual Inversion Loader. " + "Injects style tokens into CLIP conditioning without modifying model weights.", + }), }, } @@ -49,7 +53,7 @@ class SelvaSampler: CATEGORY = SELVA_CATEGORY DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance." - def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True, target_lufs=-27.0): + def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True, target_lufs=-27.0, textual_inversion=None): import dataclasses from selva_core.model.flow_matching import FlowMatching @@ -114,6 +118,18 @@ class SelvaSampler: neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \ if negative_prompt.strip() else None + # Inject textual inversion tokens into last K positions of CLIP embedding + if textual_inversion is not None: + emb = textual_inversion["embeddings"].to(device, dtype) # [K, 1024] + K = emb.shape[0] + text_clip = text_clip.clone() + text_clip[:, -K:, :] = emb.unsqueeze(0) + if neg_text_clip is not None: + neg_text_clip = neg_text_clip.clone() + neg_text_clip[:, -K:, :] = emb.unsqueeze(0) + print(f"[SelVA] Textual inversion: injected {K} tokens into CLIP conditioning", + flush=True) + conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip) empty_conditions = net_generator.get_empty_conditions( bs=1, negative_text_features=neg_text_clip diff --git a/nodes/selva_textual_inversion_loader.py b/nodes/selva_textual_inversion_loader.py new file mode 100644 index 0000000..2626000 --- /dev/null +++ b/nodes/selva_textual_inversion_loader.py @@ -0,0 +1,66 @@ +"""SelVA Textual Inversion Loader. + +Loads a .pt file produced by SelvaTextualInversionTrainer and returns a +TEXTUAL_INVERSION bundle that the SelVA Sampler can inject into text conditioning. +""" + +from pathlib import Path + +import torch +import folder_paths + +from .utils import SELVA_CATEGORY + + +class SelvaTextualInversionLoader: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "path": ("STRING", { + "default": "textual_inversion.pt", + "tooltip": "Path to a .pt file produced by SelVA Textual Inversion Trainer. " + "Relative paths resolve to the ComfyUI output directory.", + }), + }, + } + + RETURN_TYPES = ("TEXTUAL_INVERSION",) + RETURN_NAMES = ("textual_inversion",) + OUTPUT_TOOLTIPS = ("Learned token embeddings — connect to SelVA Sampler's textual_inversion input.",) + FUNCTION = "load" + CATEGORY = SELVA_CATEGORY + DESCRIPTION = ( + "Loads learned CLIP token embeddings produced by SelVA Textual Inversion Trainer. " + "Connect the output to the SelVA Sampler's optional textual_inversion input to guide " + "generation toward the training data style without degrading audio quality." + ) + + def load(self, path: str) -> tuple: + p = Path(path.strip()) + if not p.is_absolute(): + p = Path(folder_paths.get_output_directory()) / p + if not p.exists(): + raise FileNotFoundError(f"[TI Loader] File not found: {p}") + + data = torch.load(str(p), map_location="cpu", weights_only=False) + + embeddings = data["embeddings"] # [K, 1024] + n_tokens = int(data.get("n_tokens", embeddings.shape[0])) + + print(f"[TI Loader] Loaded '{p.name}' n_tokens={n_tokens} " + f"shape={tuple(embeddings.shape)}", flush=True) + if data.get("init_text"): + print(f"[TI Loader] init_text='{data['init_text']}'", flush=True) + if data.get("step"): + print(f"[TI Loader] trained {data['step']} / {data.get('steps', '?')} steps " + f"lr={data.get('lr', '?')}", flush=True) + + bundle = { + "embeddings": embeddings, # [K, 1024] float32 on CPU + "n_tokens": n_tokens, + "path": str(p), + "init_text": data.get("init_text", ""), + } + return (bundle,) diff --git a/nodes/selva_textual_inversion_trainer.py b/nodes/selva_textual_inversion_trainer.py new file mode 100644 index 0000000..41c0bdf --- /dev/null +++ b/nodes/selva_textual_inversion_trainer.py @@ -0,0 +1,367 @@ +"""SelVA Textual Inversion Trainer. + +Learns K token embedding vectors in CLIP space that guide the base model +to generate audio in the style of the training clips — without modifying +any model weights. + +Key difference from LoRA: +- ALL generator parameters are frozen (requires_grad=False) +- Only K×1024 token embeddings receive gradients +- Latents stay on the decoder's natural manifold → no quality degradation +- The learned tokens shift WHICH latents are generated, not HOW + +Usage: + 1. Train on your .npz audio features + 2. Load result with SelVA Textual Inversion Loader + 3. Connect to SelVA Sampler optional input +""" + +import copy +import math +import random +import traceback +from pathlib import Path + +import torch +import torchaudio +import comfy.utils +import folder_paths + +from .utils import SELVA_CATEGORY, get_device, soft_empty_cache +from selva_core.model.flow_matching import FlowMatching +from .selva_lora_trainer import ( + _prepare_dataset, + _spectral_metrics, + _save_spectrogram, + _pil_to_tensor, +) + + +# --------------------------------------------------------------------------- +# Eval helper with token injection +# --------------------------------------------------------------------------- + +def _eval_sample_ti(generator, learned_tokens, n_tokens, + feature_utils_orig, dataset, seq_cfg, + device, dtype, num_steps=25, seed=42, clip_idx=0): + """Inference pass with learned tokens injected into text conditioning.""" + generator.eval() + try: + _, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx] + clip_f = clip_f_cpu.to(device, dtype) + sync_f = sync_f_cpu.to(device, dtype) + text_clip = text_clip_cpu.to(device, dtype).clone() + + text_clip[:, -n_tokens:, :] = learned_tokens.detach().unsqueeze(0).to(device, dtype) + + rng = torch.Generator(device=device).manual_seed(seed) + x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim, + device=device, dtype=dtype, generator=rng) + + eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps) + + def velocity_fn(t, x): + return generator.forward(x, clip_f, sync_f, text_clip, + t.reshape(1).to(device, dtype)) + + with torch.no_grad(): + x1_pred = eval_fm.to_data(velocity_fn, x0) + x1_unnorm = generator.unnormalize(x1_pred) + + orig_dev = next(feature_utils_orig.parameters()).device + if orig_dev != device: + feature_utils_orig.to(device) + try: + spec = feature_utils_orig.decode(x1_unnorm) + audio = feature_utils_orig.vocode(spec) + finally: + if orig_dev != device: + feature_utils_orig.to(orig_dev) + + audio = audio.float().cpu() + if audio.dim() == 2: + audio = audio.unsqueeze(1) + elif audio.dim() == 3 and audio.shape[1] != 1: + audio = audio.mean(dim=1, keepdim=True) + + target_rms = 10 ** (-27.0 / 20.0) + rms = audio.pow(2).mean().sqrt().clamp(min=1e-8) + audio = (audio * (target_rms / rms)) + peak = audio.abs().max().clamp(min=1e-8) + if peak > 1.0: + audio = audio / peak + return audio.squeeze(0), seq_cfg.sampling_rate + + except Exception as e: + print(f"[TI Trainer] Eval sample failed: {e}", flush=True) + traceback.print_exc() + return None, None + finally: + generator.train() + + +# --------------------------------------------------------------------------- +# Node +# --------------------------------------------------------------------------- + +class SelvaTextualInversionTrainer: + """Learns K CLIP token embeddings that steer SelVA toward a target audio style. + + Unlike LoRA, all model weights are frozen. Only the K×1024 embedding tensor + receives gradients, keeping generated latents on the decoder's natural manifold + and preserving base model audio quality while shifting generation style. + """ + + OUTPUT_NODE = True + CATEGORY = SELVA_CATEGORY + FUNCTION = "train" + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("embeddings_path",) + OUTPUT_TOOLTIPS = ("Path to saved .pt embeddings — load with SelVA Textual Inversion Loader.",) + DESCRIPTION = ( + "Trains K learnable CLIP token embeddings against your audio dataset " + "with all model weights frozen. The tokens are then injected into the " + "sampler to guide generation toward the training data style without " + "degrading audio quality." + ) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("SELVA_MODEL",), + "data_dir": ("STRING", { + "default": "", + "tooltip": "Directory containing .npz feature files and paired audio files (same as LoRA trainer).", + }), + "output_path": ("STRING", { + "default": "textual_inversion.pt", + "tooltip": "Where to save the learned embeddings. Relative paths resolve to ComfyUI output directory.", + }), + "n_tokens": ("INT", { + "default": 4, "min": 1, "max": 16, + "tooltip": "Number of learnable token vectors. More tokens = more expressive but slower to train. 4 is a good default.", + }), + "steps": ("INT", { + "default": 3000, "min": 100, "max": 50000, + "tooltip": "Training steps. 3000 is a reasonable starting point.", + }), + "lr": ("FLOAT", { + "default": 1e-3, "min": 1e-5, "max": 1e-1, "step": 1e-5, + "tooltip": "Learning rate. 1e-3 is a good default for textual inversion (higher than LoRA since there are far fewer parameters).", + }), + "batch_size": ("INT", { + "default": 16, "min": 1, "max": 64, + "tooltip": "Clips sampled per training step.", + }), + "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}), + "save_every": ("INT", { + "default": 1000, "min": 100, "max": 10000, + "tooltip": "Save a checkpoint and generate an eval sample every N steps.", + }), + }, + "optional": { + "init_text": ("STRING", { + "default": "", + "tooltip": "Optional text phrase to warm-start token values via CLIP. Leave empty for random init (N(0, 0.02)). Example: 'industrial sound design'.", + }), + "warmup_steps": ("INT", { + "default": 100, "min": 0, "max": 1000, + "tooltip": "Linear LR warmup steps.", + }), + }, + } + + def train(self, model, data_dir, output_path, n_tokens, steps, lr, + batch_size, seed, save_every, + init_text="", warmup_steps=100): + + device = get_device() + dtype = model["dtype"] + mode = model["mode"] + seq_cfg = model["seq_cfg"] + feature_utils_orig = model["feature_utils"] + + # --- Resolve paths --- + data_dir = Path(data_dir.strip()) + if not data_dir.is_absolute(): + data_dir = Path(folder_paths.models_dir) / data_dir + if not data_dir.exists(): + raise FileNotFoundError(f"[TI Trainer] data_dir not found: {data_dir}") + + out_path = Path(output_path.strip()) + if not out_path.is_absolute(): + out_path = Path(folder_paths.get_output_directory()) / out_path + out_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"\n[TI Trainer] n_tokens={n_tokens} steps={steps} lr={lr:.2e}", flush=True) + print(f"[TI Trainer] data_dir = {data_dir}", flush=True) + print(f"[TI Trainer] output = {out_path}\n", flush=True) + + # --- Load dataset (reuse LoRA trainer helper) --- + dataset = _prepare_dataset(model, data_dir, device) + + # Training must run outside inference_mode so autograd works + with torch.inference_mode(False), torch.enable_grad(): + return self._train_inner( + model, dataset, feature_utils_orig, seq_cfg, + device, dtype, mode, + data_dir, out_path, + n_tokens, steps, lr, batch_size, + warmup_steps, seed, save_every, init_text, + ) + + def _train_inner( + self, model, dataset, feature_utils_orig, seq_cfg, + device, dtype, mode, + data_dir, out_path, + n_tokens, steps, lr, batch_size, + warmup_steps, seed, save_every, init_text, + ): + torch.manual_seed(seed) + + # --- Generator (frozen) --- + generator = copy.deepcopy(model["generator"]).to(device, dtype) + generator.requires_grad_(False) + generator.update_seq_lengths( + latent_seq_len=seq_cfg.latent_seq_len, + clip_seq_len=seq_cfg.clip_seq_len, + sync_seq_len=seq_cfg.sync_seq_len, + ) + + # --- Init learned tokens --- + # Call encode_text_clip outside the grad context (it has @inference_mode), + # grab values only (no grad needed), then wrap as nn.Parameter. + if init_text.strip(): + with torch.no_grad(): + init_embed = feature_utils_orig.encode_text_clip([init_text.strip()]) + # Positions 1:1+n_tokens — after BOS, before EOS — have actual content + init_vals = init_embed[0, 1:1 + n_tokens, :].detach().clone().float() + if init_vals.shape[0] < n_tokens: + # Prompt was very short; pad remaining with small noise + pad = torch.randn(n_tokens - init_vals.shape[0], init_vals.shape[1]) * 0.02 + init_vals = torch.cat([init_vals, pad], dim=0) + learned_tokens = torch.nn.Parameter(init_vals.to(device, dtype)) + print(f"[TI Trainer] Init from '{init_text.strip()}' (positions 1–{n_tokens})", flush=True) + else: + learned_tokens = torch.nn.Parameter( + torch.randn(n_tokens, 1024, device=device, dtype=dtype) * 0.02 + ) + print(f"[TI Trainer] Init: random N(0, 0.02)", flush=True) + + # --- Optimizer + scheduler --- + optimizer = torch.optim.AdamW([learned_tokens], lr=lr, weight_decay=1e-2) + + def lr_lambda(s): + return s / max(1, warmup_steps) if s < warmup_steps else 1.0 + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25) + + # --- Checkpoint dir --- + ckpt_dir = out_path.parent / out_path.stem + ckpt_dir.mkdir(parents=True, exist_ok=True) + + # --- Training loop --- + generator.train() + optimizer.zero_grad() + + log_interval = 50 + pbar = comfy.utils.ProgressBar(steps) + loss_history = [] + running_loss = 0.0 + + print(f"[TI Trainer] Training {steps} steps batch_size={batch_size}\n", flush=True) + + for step in range(1, steps + 1): + batch = random.choices(dataset, k=batch_size) + x1_list, clip_list, sync_list, text_list = zip(*batch) + + x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype) + clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype) + sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype) + text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone() + + # Inject learned tokens into last n_tokens positions + text_clip[:, -n_tokens:, :] = learned_tokens.unsqueeze(0).expand(batch_size, -1, -1) + + x1 = generator.normalize(x1) + t = torch.rand(batch_size, device=device, dtype=dtype) + x0 = torch.randn_like(x1) + xt = fm.get_conditional_flow(x0, x1, t) + + v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t) + loss = fm.loss(v_pred, x0, x1).mean() + loss.backward() + + torch.nn.utils.clip_grad_norm_([learned_tokens], max_norm=1.0) + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + running_loss += loss.item() + pbar.update(1) + + if step % log_interval == 0: + avg = running_loss / log_interval + loss_history.append(round(avg, 6)) + running_loss = 0.0 + lr_now = scheduler.get_last_lr()[0] + norm = learned_tokens.norm().item() + print(f"[TI Trainer] step {step:5d}/{steps} " + f"loss={avg:.4f} lr={lr_now:.2e} " + f"token_norm={norm:.4f}", flush=True) + + if step % save_every == 0 or step == steps: + # Save checkpoint + ckpt = { + "embeddings": learned_tokens.detach().cpu(), + "n_tokens": n_tokens, + "step": step, + "init_text": init_text, + "lr": lr, + "steps": steps, + "loss_history": loss_history, + } + ckpt_path = ckpt_dir / f"step_{step:05d}.pt" + torch.save(ckpt, str(ckpt_path)) + + # Eval sample + wav, sr = _eval_sample_ti( + generator, learned_tokens, n_tokens, + feature_utils_orig, dataset, seq_cfg, + device, dtype, seed=seed, + ) + if wav is not None: + wav_path = ckpt_dir / f"step_{step:05d}.wav" + try: + torchaudio.save(str(wav_path), wav, sr) + except RuntimeError: + import soundfile as sf + sf.write(str(wav_path), wav.squeeze(0).numpy(), sr) + + metrics = _spectral_metrics(wav.unsqueeze(0), sr) + if metrics: + img = _save_spectrogram(wav.squeeze(0), sr, ckpt_dir / f"step_{step:05d}.png") + print(f"[TI Trainer] step {step} " + f"centroid={metrics['spectral_centroid_hz']:.0f}Hz " + f"flatness={metrics['spectral_flatness']:.4f} " + f"hf={metrics['hf_energy_ratio']:.3f}", flush=True) + + print(f"[TI Trainer] Checkpoint: {ckpt_path}", flush=True) + + # --- Final save --- + final = { + "embeddings": learned_tokens.detach().cpu(), + "n_tokens": n_tokens, + "step": steps, + "init_text": init_text, + "lr": lr, + "steps": steps, + "loss_history": loss_history, + } + torch.save(final, str(out_path)) + print(f"\n[TI Trainer] Done. Saved: {out_path}", flush=True) + + soft_empty_cache() + return (str(out_path),)