"""SelVA DITTO Optimizer. Inference-time noise optimization: optimizes the initial noise latent x_0 using a style loss against BJ reference clips, backpropagating through the ODE solver. All model weights remain frozen — only x_0 changes. Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179, ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE. Style loss: mel-spectrogram statistics matching (mean spectrum + Gram matrix) against BJ reference clips. Runs entirely before the vocoder — optimization only requires the DiT + VAE decoder, not BigVGAN. Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT forward pass activations) instead of O(N steps). Backward recomputes each step's activations on demand. """ import dataclasses import threading from pathlib import Path import torch import torch.nn.functional as F import torchaudio import comfy.utils import comfy.model_management import folder_paths from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache def _load_wav(path): """Load audio file to [channels, samples] float32 tensor.""" try: return torchaudio.load(str(path)) except Exception: pass import soundfile as sf data, sr = sf.read(str(path), dtype="float32", always_2d=True) wav = torch.from_numpy(data.T) return wav, sr def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0): """Style loss between generated mel and precomputed reference statistics. mel_gen: [1, n_mels, T] generated mel spectrogram (with grad) ref_mean: [n_mels] mean spectrum of reference clips (detached) ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached) gram_weight: weight for Gram matrix component — 0 = mean spectrum only. Start at 0; enable only if mean-only optimization converges without noise, then increase slowly (0.01–0.1). """ m = mel_gen.squeeze(0) # [n_mels, T] # Mean spectrum loss — captures spectral envelope gen_mean = m.mean(dim=-1) # [n_mels] loss_mean = F.l1_loss(gen_mean, ref_mean) if gram_weight <= 0.0: return loss_mean # Gram matrix loss — captures timbral texture (can add noise if too high) gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels] loss_gram = F.mse_loss(gram_gen, ref_gram) return loss_mean + gram_weight * loss_gram def _latent_style_loss(z, ref_mean, ref_gram, gram_weight=0.0): """Style loss computed directly in VAE latent space. z: [T_lat, C_lat] unnormalized latent at ODE endpoint (with grad) ref_mean: [C_lat] mean latent vector of reference clips ref_gram: [C_lat, C_lat] Gram matrix of reference latents gram_weight: weight for Gram component — 0 = mean only (recommended start) Operating in latent space avoids backprop through the VAE decoder, which is @torch.inference_mode() and produces noisy, unstable gradients. """ # Mean latent loss — matches average activation per channel gen_mean = z.mean(dim=0) # [C_lat] loss_mean = F.l1_loss(gen_mean, ref_mean) if gram_weight <= 0.0: return loss_mean # Gram matrix — inter-channel covariance, position-invariant gram_gen = (z.T @ z) / z.shape[0] # [C_lat, C_lat] loss_gram = F.mse_loss(gram_gen, ref_gram) return loss_mean + gram_weight * loss_gram class SelvaDittoOptimizer: """DITTO inference-time noise optimization. Freezes all model weights and optimizes only the initial noise latent x_0 to make the generated audio sound like the BJ reference clips. No training data or gradient updates to the model — per-video per-run. """ @classmethod def INPUT_TYPES(cls): return { "required": { "model": ("SELVA_MODEL",), "features": ("SELVA_FEATURES",), "prompt": ("STRING", { "default": "", "multiline": True, "tooltip": "Sound description. Leave empty to use features prompt.", }), "negative_prompt": ("STRING", { "default": "", "multiline": False, }), "reference_dir": ("STRING", { "default": "", "tooltip": "Directory with BJ reference audio files (.wav/.flac/.mp3). " "Reference mel statistics are precomputed from these once.", }), "n_opt_steps": ("INT", { "default": 50, "min": 5, "max": 500, "tooltip": "Gradient optimization steps on x_0. 50 is a good start; " "each step requires ~2 DiT forward passes.", }), "opt_lr": ("FLOAT", { "default": 0.02, "min": 0.001, "max": 2.0, "step": 0.001, "tooltip": "Adam learning rate for x_0 optimization. " "0.02–0.05 is recommended; 0.1 (paper default) causes oscillation.", }), "n_ode_steps": ("INT", { "default": 10, "min": 5, "max": 50, "tooltip": "Euler ODE steps run during each optimization iteration. " "Lower = faster optimization (10–15 is a good trade-off). " "Final generation always uses the steps parameter below.", }), "n_grad_steps": ("INT", { "default": 5, "min": 1, "max": 50, "tooltip": "ODE steps to differentiate through (truncated BPTT). " "Higher = more accurate gradient, more VRAM. " "Must be ≤ n_ode_steps. 5 is a good default.", }), "style_weight": ("FLOAT", { "default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05, "tooltip": "Weight of the BJ style loss. High values push harder toward " "BJ style but add noise. Start at 0.1 and increase slowly.", }), "gram_weight": ("FLOAT", { "default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Weight of the Gram matrix (timbral texture) loss relative to " "the mean spectrum loss. 0 = mean spectrum only (less noise). " "0.1 adds texture matching but can introduce white noise.", }), "anchor_weight": ("FLOAT", { "default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. " "Prevents optimization from pushing x0 out of the flow's " "expected distribution (which causes white noise). " "Higher = cleaner audio, weaker style. 1.0 is a safe default.", }), "steps": ("INT", { "default": 25, "min": 1, "max": 200, "tooltip": "Euler steps for the final generation pass (after optimization).", }), "cfg_strength": ("FLOAT", { "default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), }, "optional": { "normalize": ("BOOLEAN", {"default": True}), "target_lufs": ("FLOAT", { "default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}), }, } RETURN_TYPES = ("AUDIO",) RETURN_NAMES = ("audio",) OUTPUT_TOOLTIPS = ("DITTO-optimized audio — x_0 steered toward BJ style.",) FUNCTION = "optimize" CATEGORY = SELVA_CATEGORY DESCRIPTION = ( "DITTO inference-time noise optimization (arXiv:2401.12179). " "Optimizes the initial noise latent x_0 to match BJ reference clips " "via mel statistics style loss, backpropagating through the ODE. " "All model weights frozen — zero quality degradation risk." ) def optimize(self, model, features, prompt, negative_prompt, reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed, normalize=True, target_lufs=-27.0): import traceback device = get_device() dtype = model["dtype"] strategy = model["strategy"] net_generator = model["generator"] feature_utils = model["feature_utils"] mel_converter = feature_utils.mel_converter # Validate variant match feat_variant = features.get("variant") if feat_variant is not None and feat_variant != model["variant"]: raise ValueError( f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. " f"Re-run Feature Extractor." ) if not prompt or not prompt.strip(): prompt = features.get("prompt", "") # Resolve duration and seq_cfg duration = features.get("duration", 0) if duration <= 0: raise ValueError("[DITTO] Features contain no duration field.") seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration) sample_rate = seq_cfg.sampling_rate # Load reference clips and encode to latent space. # Style loss is computed in latent space (after net_generator.unnormalize) # rather than mel space — this avoids backpropagating through the VAE # decoder (which is @torch.inference_mode() and produces noisy gradients). ref_dir = Path(reference_dir.strip()) if not ref_dir.is_absolute(): ref_dir = Path(folder_paths.models_dir) / ref_dir if not ref_dir.exists(): raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}") ref_files = [] for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg"): ref_files.extend(ref_dir.rglob(ext)) if not ref_files: raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}") if not hasattr(feature_utils.tod.vae, "encoder"): raise RuntimeError( "[DITTO] VAE encoder not available — model was loaded with need_vae_encoder=False. " "Reload the model with the encoder enabled." ) print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True) mel_converter.to(device, torch.float32) # cuFFT requires float32 ref_latents = [] with torch.no_grad(): for rf in ref_files: try: wav, sr = _load_wav(rf) if wav.shape[0] > 1: wav = wav.mean(0, keepdim=True) if sr != sample_rate: wav = torchaudio.functional.resample(wav, sr, sample_rate) wav = wav.squeeze(0).to(device, torch.float32) mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel] # encode → sample → VAE latent space (matches unnormalize(x) in loss) z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat] ref_latents.append(z_sample.to(dtype).squeeze(0).clone()) # [T_lat, C_lat] except Exception as e: print(f" [DITTO] Skip {rf.name}: {e}", flush=True) if not ref_latents: raise RuntimeError("[DITTO] No usable reference clips.") # Precompute reference latent statistics (done once — detached, no grad) with torch.no_grad(): all_means = torch.stack([z.mean(dim=0) for z in ref_latents]) ref_mean = all_means.mean(0) # [C_lat] all_grams = [(z.T @ z) / z.shape[0] for z in ref_latents] ref_gram = torch.stack(all_grams).mean(0) # [C_lat, C_lat] print(f"[DITTO] Reference latent stats from {len(ref_latents)} clips " f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} " f"grad_steps={n_grad_steps}", flush=True) if strategy == "offload_to_cpu": net_generator.to(device) feature_utils.to(device) soft_empty_cache() pbar = comfy.utils.ProgressBar(n_opt_steps + steps) _result = [None] _exc = [None] def _worker(): try: _result[0] = _do_optimize( net_generator, feature_utils, mel_converter, features, prompt, negative_prompt, ref_mean, ref_gram, seq_cfg, sample_rate, device, dtype, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed, normalize, target_lufs, pbar, ) except Exception as e: _exc[0] = e traceback.print_exc() t = threading.Thread(target=_worker, daemon=True) t.start() t.join() if strategy == "offload_to_cpu": net_generator.to(get_offload_device()) feature_utils.to(get_offload_device()) soft_empty_cache() if _exc[0] is not None: raise _exc[0] return (_result[0],) def _do_optimize(net_generator, feature_utils, mel_converter, features, prompt, negative_prompt, ref_mean, ref_gram, seq_cfg, sample_rate, device, dtype, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed, normalize, target_lufs, pbar): """Optimization loop — runs in a fresh thread (no inference_mode active).""" # Strip inference flags from ref stats (came from main thread) and cast to # model dtype. ref_mean/ref_gram are float32 (computed via cuFFT mel path); # mel_gen is model dtype (bfloat16). Mixed-dtype loss → float32 gradient → # "Found dtype Float but expected BFloat16" in backward through bfloat16 ops. ref_mean = ref_mean.clone().detach().to(dtype) ref_gram = ref_gram.clone().detach().to(dtype) torch.manual_seed(seed) clip_f = features["clip_features"].to(device, dtype).clone() sync_f = features["sync_features"].to(device, dtype).clone() # Strip inference-mode flags from all model weights and buffers BEFORE any # forward pass. Parameters were loaded in ComfyUI's inference_mode context; # operations on inference tensors produce inference tensors, so conditions # computed from tainted weights would also be tainted. clone() outside # inference_mode produces a normal tensor regardless of the source flag. def _strip_inference(module): for mod in module.modules(): for name, buf in list(mod._buffers.items()): if buf is not None: mod._buffers[name] = buf.clone() for name, param in list(mod._parameters.items()): if param is not None: mod._parameters[name] = torch.nn.Parameter( param.data.clone(), requires_grad=False ) _strip_inference(net_generator) _strip_inference(feature_utils) _strip_inference(mel_converter) net_generator.update_seq_lengths( latent_seq_len=seq_cfg.latent_seq_len, clip_seq_len=clip_f.shape[1], sync_seq_len=sync_f.shape[1], ) with torch.no_grad(): text_clip = feature_utils.encode_text_clip([prompt]) neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \ if negative_prompt.strip() else None conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip) empty_conditions = net_generator.get_empty_conditions( bs=1, negative_text_features=neg_text_clip ) # Clone all tensors inside conditions/empty_conditions to ensure no inference # flags survived from intermediate computations inside preprocess_conditions. def _clone_nested(obj): if isinstance(obj, torch.Tensor): return obj.clone() elif isinstance(obj, dict): return {k: _clone_nested(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple)): return type(obj)(_clone_nested(v) for v in obj) return obj conditions = _clone_nested(conditions) empty_conditions = _clone_nested(empty_conditions) # Initial noise — x_0 is the parameter we optimize x0_init = torch.randn( 1, seq_cfg.latent_seq_len, net_generator.latent_dim, device=device, dtype=dtype, ) x0 = torch.nn.Parameter(x0_init.clone()) x0_init = x0_init.detach() # anchor — kept fixed, no grad optimizer = torch.optim.Adam([x0], lr=opt_lr) # n_grad_steps must not exceed n_ode_steps n_grad_steps = min(n_grad_steps, n_ode_steps) n_free_steps = n_ode_steps - n_grad_steps # steps run without gradient ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype) print(f"[DITTO] Optimizing x_0 " f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True) # Freeze all model weights (double-check — should already be frozen at inference) net_generator.requires_grad_(False) feature_utils.requires_grad_(False) mel_converter.requires_grad_(False) for opt_step in range(n_opt_steps): comfy.model_management.throw_exception_if_processing_interrupted() # ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ── # Detach from x0 so Phase 1 does not build a computation graph. with torch.no_grad(): x = x0.detach() for i in range(n_free_steps): t = ts[i] dt = ts[i + 1] - t flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength) x = x + dt * flow # Straight-through estimator: reconnect x to x0's gradient path by # adding the zero tensor (x0 - x0.detach()). This adds zero value but # creates a grad_fn pointing back to x0, so loss.backward() will # propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad. # The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is # treated as identity for gradient purposes (truncated BPTT). # # x may carry an inference tensor flag from Phase 1 (derived from # conditions which were built outside inference_mode but may have # propagated the flag). .clone() strips it so the STE addition does # not try to save an inference tensor for backward. x = x.clone() x = x + (x0 - x0.detach()) # ── Phase 2: run last n_grad_steps with gradient + checkpointing ── for i in range(n_free_steps, n_ode_steps): t = ts[i] dt = ts[i + 1] - t # Gradient checkpointing: recompute forward during backward, # avoiding storage of DiT activations for each step. def _ode_step(x_in, t=t): return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength) flow = torch.utils.checkpoint.checkpoint( _ode_step, x, use_reentrant=False ) x = x + dt * flow # ── Style loss in latent space ─────────────────────────────────────── # Unnormalize x back to VAE latent space — fully differentiable, no # decode needed. ref_mean/ref_gram are computed from encoded reference # clips in the same space. Avoids backprop through VAE decoder which # is @torch.inference_mode() and produces noisy gradients. x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat] style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight) # Anchor regularization — penalize x0 drifting from its initial N(0,1) # value. Flow matching ODE expects x0 ~ N(0,1); large deviations push # the ODE into an out-of-distribution region that decodes as white noise. anchor_loss = anchor_weight * F.mse_loss(x0, x0_init) loss = style_loss + anchor_loss optimizer.zero_grad() loss.backward() # gradient flows through Phase 2 + STE back to x0.grad torch.nn.utils.clip_grad_norm_([x0], 1.0) optimizer.step() pbar.update(1) if (opt_step + 1) % max(1, n_opt_steps // 10) == 0: print(f"[DITTO] {opt_step+1}/{n_opt_steps} " f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} " f"x0_std={x0.data.std().item():.3f}", flush=True) # ── Final generation with optimized x_0 ───────────────────────────────── print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True) with torch.no_grad(): fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype) x = x0.detach() for i in range(steps): comfy.model_management.throw_exception_if_processing_interrupted() t = fm_ts[i] dt = fm_ts[i + 1] - t flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength) x = x + dt * flow pbar.update(1) x1_unnorm = net_generator.unnormalize(x) spec = feature_utils.decode(x1_unnorm) audio = feature_utils.vocode(spec) print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}", flush=True) audio = audio.float() if audio.dim() == 2: audio = audio.unsqueeze(1) elif audio.dim() == 3 and audio.shape[1] != 1: audio = audio.mean(dim=1, keepdim=True) if normalize: target_rms = 10 ** (target_lufs / 20.0) rms = audio.pow(2).mean().sqrt().clamp(min=1e-8) audio = audio * (target_rms / rms) peak = audio.abs().max().clamp(min=1e-8) if peak > 1.0: audio = audio / peak print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True) return {"waveform": audio.cpu(), "sample_rate": sample_rate}