From 633fe36fbbf4b9cc5236da9163c654040b20313b Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 18:12:31 +0200 Subject: [PATCH] fix: compute DITTO style loss in latent space to eliminate VAE decoder noise MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Root cause of white noise: backpropagating through vae.decode produces unstable gradients — the VAE decoder was designed for inference only. Fix: encode reference clips to VAE latent space once (no grad), compute mean + Gram matrix statistics there, and compute style loss directly on net_generator.unnormalize(x) — a single differentiable linear operation. The gradient path is now: loss → x (unnormalized) → ODE → x0, with no decoder in the backward pass. Also adds VAE encoder availability check (fails cleanly if encoder was deleted to save VRAM). Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_ditto_optimizer.py | 84 +++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 27 deletions(-) diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index 10167da..f00f5b4 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -68,6 +68,31 @@ def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0): return loss_mean + gram_weight * loss_gram +def _latent_style_loss(z, ref_mean, ref_gram, gram_weight=0.0): + """Style loss computed directly in VAE latent space. + + z: [T_lat, C_lat] unnormalized latent at ODE endpoint (with grad) + ref_mean: [C_lat] mean latent vector of reference clips + ref_gram: [C_lat, C_lat] Gram matrix of reference latents + gram_weight: weight for Gram component — 0 = mean only (recommended start) + + Operating in latent space avoids backprop through the VAE decoder, which + is @torch.inference_mode() and produces noisy, unstable gradients. + """ + # Mean latent loss — matches average activation per channel + gen_mean = z.mean(dim=0) # [C_lat] + loss_mean = F.l1_loss(gen_mean, ref_mean) + + if gram_weight <= 0.0: + return loss_mean + + # Gram matrix — inter-channel covariance, position-invariant + gram_gen = (z.T @ z) / z.shape[0] # [C_lat, C_lat] + loss_gram = F.mse_loss(gram_gen, ref_gram) + + return loss_mean + gram_weight * loss_gram + + class SelvaDittoOptimizer: """DITTO inference-time noise optimization. @@ -185,7 +210,10 @@ class SelvaDittoOptimizer: seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration) sample_rate = seq_cfg.sampling_rate - # Load and precompute reference mel statistics + # Load reference clips and encode to latent space. + # Style loss is computed in latent space (after net_generator.unnormalize) + # rather than mel space — this avoids backpropagating through the VAE + # decoder (which is @torch.inference_mode() and produces noisy gradients). ref_dir = Path(reference_dir.strip()) if not ref_dir.is_absolute(): ref_dir = Path(folder_paths.models_dir) / ref_dir @@ -198,10 +226,16 @@ class SelvaDittoOptimizer: if not ref_files: raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}") + if not hasattr(feature_utils.tod.vae, "encoder"): + raise RuntimeError( + "[DITTO] VAE encoder not available — model was loaded with need_vae_encoder=False. " + "Reload the model with the encoder enabled." + ) + print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True) mel_converter.to(device, torch.float32) # cuFFT requires float32 - ref_mels = [] + ref_latents = [] with torch.no_grad(): for rf in ref_files: try: @@ -210,27 +244,27 @@ class SelvaDittoOptimizer: wav = wav.mean(0, keepdim=True) if sr != sample_rate: wav = torchaudio.functional.resample(wav, sr, sample_rate) - wav = wav.squeeze(0).to(device, torch.float32) # cuFFT requires float32 - mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T] - ref_mels.append(mel) + wav = wav.squeeze(0).to(device, torch.float32) + mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T_mel] + # encode → sample → normalize (matches x at ODE endpoint) + z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution + z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat] + z_norm = net_generator.normalize(z_sample.to(dtype)) + ref_latents.append(z_norm.squeeze(0).clone()) # [T_lat, C_lat] except Exception as e: print(f" [DITTO] Skip {rf.name}: {e}", flush=True) - if not ref_mels: + if not ref_latents: raise RuntimeError("[DITTO] No usable reference clips.") - # Precompute reference statistics (done once — detached, no grad) + # Precompute reference latent statistics (done once — detached, no grad) with torch.no_grad(): - all_means = torch.stack([m.squeeze(0).mean(dim=-1) for m in ref_mels]) - ref_mean = all_means.mean(0) # [n_mels] + all_means = torch.stack([z.mean(dim=0) for z in ref_latents]) + ref_mean = all_means.mean(0) # [C_lat] + all_grams = [(z.T @ z) / z.shape[0] for z in ref_latents] + ref_gram = torch.stack(all_grams).mean(0) # [C_lat, C_lat] - all_grams = [] - for m in ref_mels: - M = m.squeeze(0) # [n_mels, T] - all_grams.append((M @ M.T) / M.shape[-1]) - ref_gram = torch.stack(all_grams).mean(0) # [n_mels, n_mels] - - print(f"[DITTO] Reference stats computed from {len(ref_mels)} clips " + print(f"[DITTO] Reference latent stats from {len(ref_latents)} clips " f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} " f"grad_steps={n_grad_steps}", flush=True) @@ -409,17 +443,13 @@ def _do_optimize(net_generator, feature_utils, mel_converter, ) x = x + dt * flow - # ── Decode to mel (no vocoder — cheap) ────────────────────────────── - # feature_utils.decode and autoencoder.decode are both decorated with - # @torch.inference_mode(), which destroys the gradient chain. - # Bypass both wrappers and call vae.decode directly — it has no - # inference_mode decorator and is fully differentiable. - # The transpose matches feature_utils.decode: [B, T, C] → [B, C, T]. - x_un = net_generator.unnormalize(x) - mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2)) - - # ── Style loss ─────────────────────────────────────────────────────── - loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight) + # ── Style loss in latent space ─────────────────────────────────────── + # Unnormalize x back to VAE latent space — fully differentiable, no + # decode needed. ref_mean/ref_gram are computed from encoded reference + # clips in the same space. Avoids backprop through VAE decoder which + # is @torch.inference_mode() and produces noisy gradients. + x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat] + loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight) optimizer.zero_grad() loss.backward() # gradient flows through Phase 2 + STE back to x0.grad