Files
ComfyUI-SelVA/nodes/selva_ditto_optimizer.py
T
Ethanfel 101b1bdb41 fix: _do_optimize returns dict not tuple — prevent double-wrapping AUDIO output
optimize() does return (_result[0],) to wrap for ComfyUI. _do_optimize was
returning (dict,) instead of dict, causing double-wrapping: ((dict,),).
ComfyUI then received a tuple as audio and failed on audio["waveform"].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:56:59 +02:00

463 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""SelVA DITTO Optimizer.
Inference-time noise optimization: optimizes the initial noise latent x_0
using a style loss against BJ reference clips, backpropagating through the
ODE solver. All model weights remain frozen — only x_0 changes.
Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179,
ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE.
Style loss: mel-spectrogram statistics matching (mean spectrum + Gram matrix)
against BJ reference clips. Runs entirely before the vocoder — optimization
only requires the DiT + VAE decoder, not BigVGAN.
Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT
forward pass activations) instead of O(N steps). Backward recomputes each
step's activations on demand.
"""
import dataclasses
import threading
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
import comfy.utils
import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
def _load_wav(path):
"""Load audio file to [channels, samples] float32 tensor."""
try:
return torchaudio.load(str(path))
except Exception:
pass
import soundfile as sf
data, sr = sf.read(str(path), dtype="float32", always_2d=True)
wav = torch.from_numpy(data.T)
return wav, sr
def _mel_style_loss(mel_gen, ref_mean, ref_gram):
"""Style loss between generated mel and precomputed reference statistics.
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
ref_mean: [n_mels] mean spectrum of BJ reference clips (detached)
ref_gram: [n_mels, n_mels] Gram matrix of BJ reference clips (detached)
Mean spectrum loss captures the spectral envelope (which harmonics are
boosted). Gram matrix loss captures timbral texture — covariance between
frequency bands — without requiring temporal alignment.
"""
m = mel_gen.squeeze(0) # [n_mels, T]
# Mean spectrum loss
gen_mean = m.mean(dim=-1) # [n_mels]
loss_mean = F.l1_loss(gen_mean, ref_mean)
# Gram matrix loss (texture, position-invariant)
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
loss_gram = F.mse_loss(gram_gen, ref_gram)
return loss_mean + 0.1 * loss_gram
class SelvaDittoOptimizer:
    """DITTO inference-time noise optimization node.

    Freezes all model weights and optimizes only the initial noise latent x_0
    to make the generated audio sound like the BJ reference clips.
    No training data or gradient updates to the model — per-video per-run.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input schema (required and optional sockets)."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "features": ("SELVA_FEATURES",),
                "prompt": ("STRING", {
                    "default": "", "multiline": True,
                    "tooltip": "Sound description. Leave empty to use features prompt.",
                }),
                "negative_prompt": ("STRING", {
                    "default": "", "multiline": False,
                }),
                "reference_dir": ("STRING", {
                    "default": "",
                    "tooltip": "Directory with BJ reference audio files (.wav/.flac/.mp3). "
                               "Reference mel statistics are precomputed from these once.",
                }),
                "n_opt_steps": ("INT", {
                    "default": 50, "min": 5, "max": 500,
                    "tooltip": "Gradient optimization steps on x_0. 50 is a good start; "
                               "each step requires ~2 DiT forward passes.",
                }),
                "opt_lr": ("FLOAT", {
                    "default": 0.1, "min": 0.001, "max": 2.0, "step": 0.01,
                    "tooltip": "Adam learning rate for x_0 optimization. "
                               "0.1 is the DITTO paper default.",
                }),
                "n_ode_steps": ("INT", {
                    "default": 10, "min": 5, "max": 50,
                    "tooltip": "Euler ODE steps run during each optimization iteration. "
                               "Lower = faster optimization (10–15 is a good trade-off). "
                               "Final generation always uses the steps parameter below.",
                }),
                "n_grad_steps": ("INT", {
                    "default": 5, "min": 1, "max": 50,
                    "tooltip": "ODE steps to differentiate through (truncated BPTT). "
                               "Higher = more accurate gradient, more VRAM. "
                               "Must be ≤ n_ode_steps. 5 is a good default.",
                }),
                "style_weight": ("FLOAT", {
                    "default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
                    "tooltip": "Weight of the BJ style loss. Increase to push harder toward "
                               "BJ style at the cost of coherence with the video.",
                }),
                "steps": ("INT", {
                    "default": 25, "min": 1, "max": 200,
                    "tooltip": "Euler steps for the final generation pass (after optimization).",
                }),
                "cfg_strength": ("FLOAT", {
                    "default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
            "optional": {
                "normalize": ("BOOLEAN", {"default": True}),
                "target_lufs": ("FLOAT", {
                    "default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    OUTPUT_TOOLTIPS = ("DITTO-optimized audio — x_0 steered toward BJ style.",)
    FUNCTION = "optimize"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "DITTO inference-time noise optimization (arXiv:2401.12179). "
        "Optimizes the initial noise latent x_0 to match BJ reference clips "
        "via mel statistics style loss, backpropagating through the ODE. "
        "All model weights frozen — zero quality degradation risk."
    )

    def optimize(self, model, features, prompt, negative_prompt,
                 reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                 style_weight, steps, cfg_strength, seed,
                 normalize=True, target_lufs=-27.0):
        """Node entry point: build reference statistics, then optimize x_0.

        Reference clips are loaded from ``reference_dir`` and reduced to a
        mean spectrum and a mean Gram matrix. The actual optimization is
        delegated to ``_do_optimize``, which is run in a separate thread and
        joined immediately (its docstring states the fresh thread is used so
        that no inference_mode is active).

        Args:
            model: Dict read for the keys ``dtype``, ``strategy``,
                ``generator``, ``feature_utils``, ``variant`` and ``seq_cfg``.
            features: Dict read for ``variant``, ``prompt``, ``duration``
                (and, in the worker, ``clip_features`` / ``sync_features``).
            prompt: Text prompt; when blank, ``features["prompt"]`` is used.
            negative_prompt: Negative text prompt (may be empty).
            reference_dir: Directory of reference audio. Relative paths are
                resolved against ``folder_paths.models_dir``.
            n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, style_weight,
            steps, cfg_strength, seed, normalize, target_lufs: Forwarded
                unchanged to ``_do_optimize``.

        Returns:
            A 1-tuple containing the audio dict produced by ``_do_optimize``
            (keys ``waveform`` and ``sample_rate``).

        Raises:
            ValueError: Variant mismatch, missing/non-positive duration, or
                an empty ``reference_dir``.
            FileNotFoundError: ``reference_dir`` does not exist or holds no
                audio files.
            RuntimeError: None of the reference clips could be loaded.
            Exception: Anything raised inside the worker is re-raised here.
        """
        import traceback

        max_ref_clips = 32  # cap on reference clips, for speed

        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        net_generator = model["generator"]
        feature_utils = model["feature_utils"]
        mel_converter = feature_utils.mel_converter

        # Validate variant match
        feat_variant = features.get("variant")
        if feat_variant is not None and feat_variant != model["variant"]:
            raise ValueError(
                f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. "
                f"Re-run Feature Extractor."
            )

        if not prompt or not prompt.strip():
            prompt = features.get("prompt", "")

        # Resolve duration and seq_cfg. ``or 0`` also covers an explicit None,
        # which would otherwise fail the comparison with a TypeError.
        duration = features.get("duration") or 0
        if duration <= 0:
            raise ValueError("[DITTO] Features contain no duration field.")
        seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
        sample_rate = seq_cfg.sampling_rate

        # Resolve the reference directory. An empty string would otherwise
        # resolve to the models root and recursively pick up every audio file
        # under it, so reject it explicitly.
        ref_dir_text = reference_dir.strip()
        if not ref_dir_text:
            raise ValueError(
                "[DITTO] reference_dir is empty — set it to a directory of reference audio files."
            )
        ref_dir = Path(ref_dir_text)
        if not ref_dir.is_absolute():
            ref_dir = Path(folder_paths.models_dir) / ref_dir
        if not ref_dir.exists():
            raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}")

        # Sorted so that the capped subset below is the same on every run;
        # rglob itself makes no ordering promise.
        ref_files = sorted(
            found
            for pattern in ("*.wav", "*.flac", "*.mp3", "*.ogg")
            for found in ref_dir.rglob(pattern)
        )
        if not ref_files:
            raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")

        selected_files = ref_files[:max_ref_clips]
        print(f"[DITTO] Loading {len(selected_files)} of {len(ref_files)} reference clips...",
              flush=True)

        mel_converter.to(device, torch.float32)  # cuFFT requires float32
        ref_mels = []
        with torch.no_grad():
            for rf in selected_files:
                try:
                    wav, sr = _load_wav(rf)
                    if wav.shape[0] > 1:
                        wav = wav.mean(0, keepdim=True)  # downmix to mono
                    if sr != sample_rate:
                        wav = torchaudio.functional.resample(wav, sr, sample_rate)
                    wav = wav.squeeze(0).to(device, torch.float32)  # cuFFT requires float32
                    mel = mel_converter(wav.unsqueeze(0))  # [1, n_mels, T]
                    ref_mels.append(mel)
                except Exception as e:
                    # Best-effort: one unreadable clip must not abort the run.
                    print(f" [DITTO] Skip {rf.name}: {e}", flush=True)
        if not ref_mels:
            raise RuntimeError("[DITTO] No usable reference clips.")

        # Precompute reference statistics (done once — detached, no grad)
        with torch.no_grad():
            all_means = torch.stack([m.squeeze(0).mean(dim=-1) for m in ref_mels])
            ref_mean = all_means.mean(0)  # [n_mels]
            all_grams = []
            for m in ref_mels:
                M = m.squeeze(0)  # [n_mels, T]
                all_grams.append((M @ M.T) / M.shape[-1])
            ref_gram = torch.stack(all_grams).mean(0)  # [n_mels, n_mels]

        print(f"[DITTO] Reference stats computed from {len(ref_mels)} clips "
              f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
              f"grad_steps={n_grad_steps}", flush=True)

        if strategy == "offload_to_cpu":
            net_generator.to(device)
            feature_utils.to(device)
            soft_empty_cache()

        pbar = comfy.utils.ProgressBar(n_opt_steps + steps)
        _result = [None]
        _exc = [None]

        def _worker():
            try:
                _result[0] = _do_optimize(
                    net_generator, feature_utils, mel_converter,
                    features, prompt, negative_prompt,
                    ref_mean, ref_gram,
                    seq_cfg, sample_rate, device, dtype,
                    n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                    style_weight, steps, cfg_strength, seed,
                    normalize, target_lufs, pbar,
                )
            except Exception as e:
                # Capture for re-raise on the calling thread.
                _exc[0] = e
                traceback.print_exc()

        try:
            t = threading.Thread(target=_worker, daemon=True)
            t.start()
            t.join()
        finally:
            # Always hand the device memory back, even if the join itself is
            # interrupted on the calling thread.
            if strategy == "offload_to_cpu":
                net_generator.to(get_offload_device())
                feature_utils.to(get_offload_device())
                soft_empty_cache()

        if _exc[0] is not None:
            raise _exc[0]
        # _do_optimize returns the audio dict; wrap it once for ComfyUI.
        return (_result[0],)
def _do_optimize(net_generator, feature_utils, mel_converter,
                 features, prompt, negative_prompt,
                 ref_mean, ref_gram,
                 seq_cfg, sample_rate, device, dtype,
                 n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                 style_weight, steps, cfg_strength, seed,
                 normalize, target_lufs, pbar):
    """Optimization loop — runs in a fresh thread (no inference_mode active).

    Optimizes the initial noise latent x_0 with Adam against the mel style
    loss, then runs one final Euler pass from the optimized x_0 and vocodes
    the result.

    Args:
        net_generator: Generator module; used via ``ode_wrapper``,
            ``preprocess_conditions``, ``get_empty_conditions``,
            ``update_seq_lengths``, ``unnormalize`` and ``latent_dim``.
        feature_utils: Provides ``encode_text_clip``, ``decode``, ``vocode``
            and ``tod.vae.decode``.
        mel_converter: Mel module; here it is only cloned and frozen.
        features: Dict read for ``clip_features`` and ``sync_features``.
        prompt, negative_prompt: Text conditions; an empty/blank negative
            prompt yields ``None`` negative text features.
        ref_mean, ref_gram: Reference statistics passed to the style loss.
        seq_cfg: Read for ``latent_seq_len``.
        sample_rate: Copied into the returned dict.
        device, dtype: Placement for features, noise and time grids.
        n_opt_steps: Number of Adam iterations on x_0.
        opt_lr: Adam learning rate.
        n_ode_steps: Euler steps per optimization iteration.
        n_grad_steps: Trailing Euler steps that carry gradient; clamped to
            ``n_ode_steps``.
        style_weight: Multiplier on the style loss.
        steps: Euler steps for the final generation pass.
        cfg_strength: Passed through to ``ode_wrapper``.
        seed: Fed to ``torch.manual_seed`` (sets the global RNG).
        normalize: When true, rescale the audio by RMS and cap the peak.
        target_lufs: Target level in dB; applied as an RMS target
            (``10 ** (target_lufs / 20)``).
        pbar: Progress bar; ``update(1)`` is called once per optimization
            iteration and once per final Euler step.

    Returns:
        ``{"waveform": <CPU float tensor>, "sample_rate": sample_rate}``.
        The waveform is presumably ``[batch, 1, samples]`` — TODO confirm
        against ``feature_utils.vocode``.

    Side effects:
        Replaces every parameter and buffer of ``net_generator``,
        ``feature_utils`` and ``mel_converter`` with a clone and sets
        ``requires_grad`` to False on all three.
    """
    # Strip inference flags from ref stats (came from main thread) and cast to
    # model dtype. ref_mean/ref_gram are float32 (computed via cuFFT mel path);
    # mel_gen is model dtype (bfloat16). Mixed-dtype loss → float32 gradient →
    # "Found dtype Float but expected BFloat16" in backward through bfloat16 ops.
    ref_mean = ref_mean.clone().detach().to(dtype)
    ref_gram = ref_gram.clone().detach().to(dtype)
    torch.manual_seed(seed)
    clip_f = features["clip_features"].to(device, dtype).clone()
    sync_f = features["sync_features"].to(device, dtype).clone()

    # Strip inference-mode flags from all model weights and buffers BEFORE any
    # forward pass. Parameters were loaded in ComfyUI's inference_mode context;
    # operations on inference tensors produce inference tensors, so conditions
    # computed from tainted weights would also be tainted. clone() outside
    # inference_mode produces a normal tensor regardless of the source flag.
    def _strip_inference(module):
        # Rebinds entries in _buffers/_parameters in place; list(...) snapshots
        # the items so the dicts can be mutated while looping.
        for mod in module.modules():
            for name, buf in list(mod._buffers.items()):
                if buf is not None:
                    mod._buffers[name] = buf.clone()
            for name, param in list(mod._parameters.items()):
                if param is not None:
                    mod._parameters[name] = torch.nn.Parameter(
                        param.data.clone(), requires_grad=False
                    )

    _strip_inference(net_generator)
    _strip_inference(feature_utils)
    _strip_inference(mel_converter)

    net_generator.update_seq_lengths(
        latent_seq_len=seq_cfg.latent_seq_len,
        clip_seq_len=clip_f.shape[1],
        sync_seq_len=sync_f.shape[1],
    )

    # Conditions are constants for the whole optimization; build them once
    # without tracking gradients.
    with torch.no_grad():
        text_clip = feature_utils.encode_text_clip([prompt])
        neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
            if negative_prompt.strip() else None
        conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
        empty_conditions = net_generator.get_empty_conditions(
            bs=1, negative_text_features=neg_text_clip
        )

    # Clone all tensors inside conditions/empty_conditions to ensure no inference
    # flags survived from intermediate computations inside preprocess_conditions.
    def _clone_nested(obj):
        # Recurses through dicts, lists and tuples; non-tensor leaves are
        # returned as-is.
        if isinstance(obj, torch.Tensor):
            return obj.clone()
        elif isinstance(obj, dict):
            return {k: _clone_nested(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return type(obj)(_clone_nested(v) for v in obj)
        return obj

    conditions = _clone_nested(conditions)
    empty_conditions = _clone_nested(empty_conditions)

    # Initial noise — x_0 is the parameter we optimize
    x0_init = torch.randn(
        1, seq_cfg.latent_seq_len, net_generator.latent_dim,
        device=device, dtype=dtype,
    )
    x0 = torch.nn.Parameter(x0_init.clone())
    optimizer = torch.optim.Adam([x0], lr=opt_lr)

    # n_grad_steps must not exceed n_ode_steps
    n_grad_steps = min(n_grad_steps, n_ode_steps)
    n_free_steps = n_ode_steps - n_grad_steps  # steps run without gradient
    # Uniform time grid on [0, 1] for the Euler integrator.
    ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype)
    print(f"[DITTO] Optimizing x_0 "
          f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True)

    # Freeze all model weights (double-check — should already be frozen at inference)
    net_generator.requires_grad_(False)
    feature_utils.requires_grad_(False)
    mel_converter.requires_grad_(False)

    for opt_step in range(n_opt_steps):
        # Lets the user cancel between iterations.
        comfy.model_management.throw_exception_if_processing_interrupted()

        # ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
        # Detach from x0 so Phase 1 does not build a computation graph.
        with torch.no_grad():
            x = x0.detach()
            for i in range(n_free_steps):
                t = ts[i]
                dt = ts[i + 1] - t
                flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
                x = x + dt * flow

        # Straight-through estimator: reconnect x to x0's gradient path by
        # adding the zero tensor (x0 - x0.detach()). This adds zero value but
        # creates a grad_fn pointing back to x0, so loss.backward() will
        # propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
        # The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
        # treated as identity for gradient purposes (truncated BPTT).
        #
        # x may carry an inference tensor flag from Phase 1 (derived from
        # conditions which were built outside inference_mode but may have
        # propagated the flag). .clone() strips it so the STE addition does
        # not try to save an inference tensor for backward.
        x = x.clone()
        x = x + (x0 - x0.detach())

        # ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
        for i in range(n_free_steps, n_ode_steps):
            t = ts[i]
            dt = ts[i + 1] - t

            # Gradient checkpointing: recompute forward during backward,
            # avoiding storage of DiT activations for each step.
            # ``t=t`` pins the current time value as a default argument so the
            # recomputation during backward does not see a later loop value.
            def _ode_step(x_in, t=t):
                return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength)

            flow = torch.utils.checkpoint.checkpoint(
                _ode_step, x, use_reentrant=False
            )
            x = x + dt * flow

        # ── Decode to mel (no vocoder — cheap) ──────────────────────────────
        # feature_utils.decode and autoencoder.decode are both decorated with
        # @torch.inference_mode(), which destroys the gradient chain.
        # Bypass both wrappers and call vae.decode directly — it has no
        # inference_mode decorator and is fully differentiable.
        # The transpose matches feature_utils.decode: [B, T, C] → [B, C, T].
        x_un = net_generator.unnormalize(x)
        mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2))

        # ── Style loss ───────────────────────────────────────────────────────
        loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
        optimizer.zero_grad()
        loss.backward()  # gradient flows through Phase 2 + STE back to x0.grad
        # Cap the gradient norm at 1.0 before the Adam update.
        torch.nn.utils.clip_grad_norm_([x0], 1.0)
        optimizer.step()
        pbar.update(1)

        # Log roughly ten times over the run (every step when n_opt_steps < 10).
        if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
            print(f"[DITTO] {opt_step+1}/{n_opt_steps} loss={loss.item():.4f}", flush=True)

    # ── Final generation with optimized x_0 ─────────────────────────────────
    print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)
    with torch.no_grad():
        # Separate time grid: the final pass uses ``steps``, not n_ode_steps.
        fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
        x = x0.detach()
        for i in range(steps):
            comfy.model_management.throw_exception_if_processing_interrupted()
            t = fm_ts[i]
            dt = fm_ts[i + 1] - t
            flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
            x = x + dt * flow
            pbar.update(1)
        # Gradients are no longer needed, so the regular decode path plus the
        # vocoder is used here.
        x1_unnorm = net_generator.unnormalize(x)
        spec = feature_utils.decode(x1_unnorm)
        audio = feature_utils.vocode(spec)

    print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}",
          flush=True)

    audio = audio.float()
    # Ensure a singleton dimension at index 1: insert one for 2-D input, or
    # average over dim 1 when a 3-D input has more than one entry there.
    if audio.dim() == 2:
        audio = audio.unsqueeze(1)
    elif audio.dim() == 3 and audio.shape[1] != 1:
        audio = audio.mean(dim=1, keepdim=True)

    if normalize:
        # target_lufs is applied as a plain RMS level in dB; the clamp guards
        # against division by zero on silent output.
        target_rms = 10 ** (target_lufs / 20.0)
        rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
        audio = audio * (target_rms / rms)
        # Scale down only if the RMS gain pushed the peak above full scale.
        peak = audio.abs().max().clamp(min=1e-8)
        if peak > 1.0:
            audio = audio / peak

    print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
    return {"waveform": audio.cpu(), "sample_rate": sample_rate}