fix: compute DITTO style loss in latent space to eliminate VAE decoder noise
Root cause of white noise: backpropagating through vae.decode produces unstable gradients — the VAE decoder was designed for inference only. Fix: encode reference clips to VAE latent space once (no grad), compute mean + Gram matrix statistics there, and compute style loss directly on net_generator.unnormalize(x) — a single differentiable linear operation. The gradient path is now: loss → x (unnormalized) → ODE → x0, with no decoder in the backward pass. Also adds VAE encoder availability check (fails cleanly if encoder was deleted to save VRAM). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -68,6 +68,31 @@ def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
|
||||
return loss_mean + gram_weight * loss_gram
|
||||
|
||||
|
||||
def _latent_style_loss(z, ref_mean, ref_gram, gram_weight=0.0):
|
||||
"""Style loss computed directly in VAE latent space.
|
||||
|
||||
z: [T_lat, C_lat] unnormalized latent at ODE endpoint (with grad)
|
||||
ref_mean: [C_lat] mean latent vector of reference clips
|
||||
ref_gram: [C_lat, C_lat] Gram matrix of reference latents
|
||||
gram_weight: weight for Gram component — 0 = mean only (recommended start)
|
||||
|
||||
Operating in latent space avoids backprop through the VAE decoder, which
|
||||
is @torch.inference_mode() and produces noisy, unstable gradients.
|
||||
"""
|
||||
# Mean latent loss — matches average activation per channel
|
||||
gen_mean = z.mean(dim=0) # [C_lat]
|
||||
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
||||
|
||||
if gram_weight <= 0.0:
|
||||
return loss_mean
|
||||
|
||||
# Gram matrix — inter-channel covariance, position-invariant
|
||||
gram_gen = (z.T @ z) / z.shape[0] # [C_lat, C_lat]
|
||||
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
||||
|
||||
return loss_mean + gram_weight * loss_gram
|
||||
|
||||
|
||||
class SelvaDittoOptimizer:
|
||||
"""DITTO inference-time noise optimization.
|
||||
|
||||
@@ -185,7 +210,10 @@ class SelvaDittoOptimizer:
|
||||
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
|
||||
sample_rate = seq_cfg.sampling_rate
|
||||
|
||||
# Load and precompute reference mel statistics
|
||||
# Load reference clips and encode to latent space.
|
||||
# Style loss is computed in latent space (after net_generator.unnormalize)
|
||||
# rather than mel space — this avoids backpropagating through the VAE
|
||||
# decoder (which is @torch.inference_mode() and produces noisy gradients).
|
||||
ref_dir = Path(reference_dir.strip())
|
||||
if not ref_dir.is_absolute():
|
||||
ref_dir = Path(folder_paths.models_dir) / ref_dir
|
||||
@@ -198,10 +226,16 @@ class SelvaDittoOptimizer:
|
||||
if not ref_files:
|
||||
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
|
||||
|
||||
if not hasattr(feature_utils.tod.vae, "encoder"):
|
||||
raise RuntimeError(
|
||||
"[DITTO] VAE encoder not available — model was loaded with need_vae_encoder=False. "
|
||||
"Reload the model with the encoder enabled."
|
||||
)
|
||||
|
||||
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
|
||||
mel_converter.to(device, torch.float32) # cuFFT requires float32
|
||||
|
||||
ref_mels = []
|
||||
ref_latents = []
|
||||
with torch.no_grad():
|
||||
for rf in ref_files:
|
||||
try:
|
||||
@@ -210,27 +244,27 @@ class SelvaDittoOptimizer:
|
||||
wav = wav.mean(0, keepdim=True)
|
||||
if sr != sample_rate:
|
||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
||||
wav = wav.squeeze(0).to(device, torch.float32) # cuFFT requires float32
|
||||
mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T]
|
||||
ref_mels.append(mel)
|
||||
wav = wav.squeeze(0).to(device, torch.float32)
|
||||
mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||
# encode → sample → normalize (matches x at ODE endpoint)
|
||||
z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution
|
||||
z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]
|
||||
z_norm = net_generator.normalize(z_sample.to(dtype))
|
||||
ref_latents.append(z_norm.squeeze(0).clone()) # [T_lat, C_lat]
|
||||
except Exception as e:
|
||||
print(f" [DITTO] Skip {rf.name}: {e}", flush=True)
|
||||
|
||||
if not ref_mels:
|
||||
if not ref_latents:
|
||||
raise RuntimeError("[DITTO] No usable reference clips.")
|
||||
|
||||
# Precompute reference statistics (done once — detached, no grad)
|
||||
# Precompute reference latent statistics (done once — detached, no grad)
|
||||
with torch.no_grad():
|
||||
all_means = torch.stack([m.squeeze(0).mean(dim=-1) for m in ref_mels])
|
||||
ref_mean = all_means.mean(0) # [n_mels]
|
||||
all_means = torch.stack([z.mean(dim=0) for z in ref_latents])
|
||||
ref_mean = all_means.mean(0) # [C_lat]
|
||||
all_grams = [(z.T @ z) / z.shape[0] for z in ref_latents]
|
||||
ref_gram = torch.stack(all_grams).mean(0) # [C_lat, C_lat]
|
||||
|
||||
all_grams = []
|
||||
for m in ref_mels:
|
||||
M = m.squeeze(0) # [n_mels, T]
|
||||
all_grams.append((M @ M.T) / M.shape[-1])
|
||||
ref_gram = torch.stack(all_grams).mean(0) # [n_mels, n_mels]
|
||||
|
||||
print(f"[DITTO] Reference stats computed from {len(ref_mels)} clips "
|
||||
print(f"[DITTO] Reference latent stats from {len(ref_latents)} clips "
|
||||
f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
|
||||
f"grad_steps={n_grad_steps}", flush=True)
|
||||
|
||||
@@ -409,17 +443,13 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
||||
)
|
||||
x = x + dt * flow
|
||||
|
||||
# ── Decode to mel (no vocoder — cheap) ──────────────────────────────
|
||||
# feature_utils.decode and autoencoder.decode are both decorated with
|
||||
# @torch.inference_mode(), which destroys the gradient chain.
|
||||
# Bypass both wrappers and call vae.decode directly — it has no
|
||||
# inference_mode decorator and is fully differentiable.
|
||||
# The transpose matches feature_utils.decode: [B, T, C] → [B, C, T].
|
||||
x_un = net_generator.unnormalize(x)
|
||||
mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2))
|
||||
|
||||
# ── Style loss ───────────────────────────────────────────────────────
|
||||
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight)
|
||||
# ── Style loss in latent space ───────────────────────────────────────
|
||||
# Unnormalize x back to VAE latent space — fully differentiable, no
|
||||
# decode needed. ref_mean/ref_gram are computed from encoded reference
|
||||
# clips in the same space. Avoids backprop through VAE decoder which
|
||||
# is @torch.inference_mode() and produces noisy gradients.
|
||||
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
|
||||
loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
||||
|
||||
Reference in New Issue
Block a user