fix: compute DITTO style loss in latent space to eliminate VAE decoder noise
Root cause of white noise: backpropagating through vae.decode produces unstable gradients — the VAE decoder was designed for inference only.

Fix: encode the reference clips to VAE latent space once (no grad), compute mean + Gram matrix statistics there, and compute the style loss directly on net_generator.unnormalize(x) — a single differentiable linear operation. The gradient path is now: loss → x (unnormalized) → ODE → x0, with no decoder in the backward pass.

Also adds a VAE encoder availability check (fails cleanly if the encoder was deleted to save VRAM).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -68,6 +68,31 @@ def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
|
|||||||
return loss_mean + gram_weight * loss_gram
|
return loss_mean + gram_weight * loss_gram
|
||||||
|
|
||||||
|
|
||||||
|
def _latent_style_loss(z, ref_mean, ref_gram, gram_weight=0.0):
|
||||||
|
"""Style loss computed directly in VAE latent space.
|
||||||
|
|
||||||
|
z: [T_lat, C_lat] unnormalized latent at ODE endpoint (with grad)
|
||||||
|
ref_mean: [C_lat] mean latent vector of reference clips
|
||||||
|
ref_gram: [C_lat, C_lat] Gram matrix of reference latents
|
||||||
|
gram_weight: weight for Gram component — 0 = mean only (recommended start)
|
||||||
|
|
||||||
|
Operating in latent space avoids backprop through the VAE decoder, which
|
||||||
|
is @torch.inference_mode() and produces noisy, unstable gradients.
|
||||||
|
"""
|
||||||
|
# Mean latent loss — matches average activation per channel
|
||||||
|
gen_mean = z.mean(dim=0) # [C_lat]
|
||||||
|
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
||||||
|
|
||||||
|
if gram_weight <= 0.0:
|
||||||
|
return loss_mean
|
||||||
|
|
||||||
|
# Gram matrix — inter-channel covariance, position-invariant
|
||||||
|
gram_gen = (z.T @ z) / z.shape[0] # [C_lat, C_lat]
|
||||||
|
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
||||||
|
|
||||||
|
return loss_mean + gram_weight * loss_gram
|
||||||
|
|
||||||
|
|
||||||
class SelvaDittoOptimizer:
|
class SelvaDittoOptimizer:
|
||||||
"""DITTO inference-time noise optimization.
|
"""DITTO inference-time noise optimization.
|
||||||
|
|
||||||
@@ -185,7 +210,10 @@ class SelvaDittoOptimizer:
|
|||||||
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
|
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
|
||||||
sample_rate = seq_cfg.sampling_rate
|
sample_rate = seq_cfg.sampling_rate
|
||||||
|
|
||||||
# Load and precompute reference mel statistics
|
# Load reference clips and encode to latent space.
|
||||||
|
# Style loss is computed in latent space (after net_generator.unnormalize)
|
||||||
|
# rather than mel space — this avoids backpropagating through the VAE
|
||||||
|
# decoder (which yields noisy gradients; only its wrappers are @torch.inference_mode()).
|
||||||
ref_dir = Path(reference_dir.strip())
|
ref_dir = Path(reference_dir.strip())
|
||||||
if not ref_dir.is_absolute():
|
if not ref_dir.is_absolute():
|
||||||
ref_dir = Path(folder_paths.models_dir) / ref_dir
|
ref_dir = Path(folder_paths.models_dir) / ref_dir
|
||||||
@@ -198,10 +226,16 @@ class SelvaDittoOptimizer:
|
|||||||
if not ref_files:
|
if not ref_files:
|
||||||
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
|
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
|
||||||
|
|
||||||
|
if not hasattr(feature_utils.tod.vae, "encoder"):
|
||||||
|
raise RuntimeError(
|
||||||
|
"[DITTO] VAE encoder not available — model was loaded with need_vae_encoder=False. "
|
||||||
|
"Reload the model with the encoder enabled."
|
||||||
|
)
|
||||||
|
|
||||||
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
|
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
|
||||||
mel_converter.to(device, torch.float32) # cuFFT requires float32
|
mel_converter.to(device, torch.float32) # cuFFT requires float32
|
||||||
|
|
||||||
ref_mels = []
|
ref_latents = []
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for rf in ref_files:
|
for rf in ref_files:
|
||||||
try:
|
try:
|
||||||
@@ -210,27 +244,27 @@ class SelvaDittoOptimizer:
|
|||||||
wav = wav.mean(0, keepdim=True)
|
wav = wav.mean(0, keepdim=True)
|
||||||
if sr != sample_rate:
|
if sr != sample_rate:
|
||||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
||||||
wav = wav.squeeze(0).to(device, torch.float32) # cuFFT requires float32
|
wav = wav.squeeze(0).to(device, torch.float32)
|
||||||
mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T]
|
mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||||
ref_mels.append(mel)
|
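# encode → sample → normalize (matches x at ODE endpoint).
# NOTE(review): the style loss is later taken on net_generator.unnormalize(x),
# so these normalized reference stats look like a latent-space mismatch — confirm.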
|
||||||
|
z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution
|
||||||
|
z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]
|
||||||
|
z_norm = net_generator.normalize(z_sample.to(dtype))
|
||||||
|
ref_latents.append(z_norm.squeeze(0).clone()) # [T_lat, C_lat]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" [DITTO] Skip {rf.name}: {e}", flush=True)
|
print(f" [DITTO] Skip {rf.name}: {e}", flush=True)
|
||||||
|
|
||||||
if not ref_mels:
|
if not ref_latents:
|
||||||
raise RuntimeError("[DITTO] No usable reference clips.")
|
raise RuntimeError("[DITTO] No usable reference clips.")
|
||||||
|
|
||||||
# Precompute reference statistics (done once — detached, no grad)
|
# Precompute reference latent statistics (done once — detached, no grad)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
all_means = torch.stack([m.squeeze(0).mean(dim=-1) for m in ref_mels])
|
all_means = torch.stack([z.mean(dim=0) for z in ref_latents])
|
||||||
ref_mean = all_means.mean(0) # [n_mels]
|
ref_mean = all_means.mean(0) # [C_lat]
|
||||||
|
all_grams = [(z.T @ z) / z.shape[0] for z in ref_latents]
|
||||||
|
ref_gram = torch.stack(all_grams).mean(0) # [C_lat, C_lat]
|
||||||
|
|
||||||
all_grams = []
|
print(f"[DITTO] Reference latent stats from {len(ref_latents)} clips "
|
||||||
for m in ref_mels:
|
|
||||||
M = m.squeeze(0) # [n_mels, T]
|
|
||||||
all_grams.append((M @ M.T) / M.shape[-1])
|
|
||||||
ref_gram = torch.stack(all_grams).mean(0) # [n_mels, n_mels]
|
|
||||||
|
|
||||||
print(f"[DITTO] Reference stats computed from {len(ref_mels)} clips "
|
|
||||||
f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
|
f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
|
||||||
f"grad_steps={n_grad_steps}", flush=True)
|
f"grad_steps={n_grad_steps}", flush=True)
|
||||||
|
|
||||||
@@ -409,17 +443,13 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
)
|
)
|
||||||
x = x + dt * flow
|
x = x + dt * flow
|
||||||
|
|
||||||
# ── Decode to mel (no vocoder — cheap) ──────────────────────────────
|
# ── Style loss in latent space ───────────────────────────────────────
|
||||||
# feature_utils.decode and autoencoder.decode are both decorated with
|
# Unnormalize x back to VAE latent space — fully differentiable, no
|
||||||
# @torch.inference_mode(), which destroys the gradient chain.
|
# decode needed. ref_mean/ref_gram are computed from encoded reference
|
||||||
# Bypass both wrappers and call vae.decode directly — it has no
|
# clips in the same space. Avoids backprop through VAE decoder which
|
||||||
# inference_mode decorator and is fully differentiable.
|
# yields noisy gradients (only its wrappers are @torch.inference_mode()).
|
||||||
# The transpose matches feature_utils.decode: [B, T, C] → [B, C, T].
|
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
|
||||||
x_un = net_generator.unnormalize(x)
|
loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
|
||||||
mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2))
|
|
||||||
|
|
||||||
# ── Style loss ───────────────────────────────────────────────────────
|
|
||||||
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight)
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
||||||
|
|||||||
Reference in New Issue
Block a user