feat: add gram_weight param to DITTO, reduce default style_weight to 0.1

White noise on output was caused by the Gram matrix loss pushing the latent
into incoherent regions. Now gram_weight defaults to 0 (mean spectrum only)
and style_weight defaults to 0.1 instead of 1.0. Users can raise gram_weight
gradually once mean-only optimization converges cleanly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 18:03:32 +02:00
parent 101b1bdb41
commit 608e7df04b
+26 -18
View File
@@ -42,28 +42,30 @@ def _load_wav(path):
return wav, sr return wav, sr
def _mel_style_loss(mel_gen, ref_mean, ref_gram): def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
"""Style loss between generated mel and precomputed reference statistics. """Style loss between generated mel and precomputed reference statistics.
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad) mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
ref_mean: [n_mels] mean spectrum of BJ reference clips (detached) ref_mean: [n_mels] mean spectrum of reference clips (detached)
ref_gram: [n_mels, n_mels] Gram matrix of BJ reference clips (detached) ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached)
gram_weight: weight for Gram matrix component — 0 = mean spectrum only.
Mean spectrum loss captures the spectral envelope (which harmonics are Start at 0; enable only if mean-only optimization converges
boosted). Gram matrix loss captures timbral texture — covariance between without noise, then increase slowly (0.01 → 0.1).
frequency bands — without requiring temporal alignment.
""" """
m = mel_gen.squeeze(0) # [n_mels, T] m = mel_gen.squeeze(0) # [n_mels, T]
# Mean spectrum loss # Mean spectrum loss — captures spectral envelope
gen_mean = m.mean(dim=-1) # [n_mels] gen_mean = m.mean(dim=-1) # [n_mels]
loss_mean = F.l1_loss(gen_mean, ref_mean) loss_mean = F.l1_loss(gen_mean, ref_mean)
# Gram matrix loss (texture, position-invariant) if gram_weight <= 0.0:
return loss_mean
# Gram matrix loss — captures timbral texture (can add noise if too high)
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels] gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
loss_gram = F.mse_loss(gram_gen, ref_gram) loss_gram = F.mse_loss(gram_gen, ref_gram)
return loss_mean + 0.1 * loss_gram return loss_mean + gram_weight * loss_gram
class SelvaDittoOptimizer: class SelvaDittoOptimizer:
@@ -115,9 +117,15 @@ class SelvaDittoOptimizer:
"Must be ≤ n_ode_steps. 5 is a good default.", "Must be ≤ n_ode_steps. 5 is a good default.",
}), }),
"style_weight": ("FLOAT", { "style_weight": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05,
"tooltip": "Weight of the BJ style loss. Increase to push harder toward " "tooltip": "Weight of the BJ style loss. High values push harder toward "
"BJ style at the cost of coherence with the video.", "BJ style but add noise. Start at 0.1 and increase slowly.",
}),
"gram_weight": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,
"tooltip": "Weight of the Gram matrix (timbral texture) loss relative to "
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
"0.1 adds texture matching but can introduce white noise.",
}), }),
"steps": ("INT", { "steps": ("INT", {
"default": 25, "min": 1, "max": 200, "default": 25, "min": 1, "max": 200,
@@ -148,7 +156,7 @@ class SelvaDittoOptimizer:
def optimize(self, model, features, prompt, negative_prompt, def optimize(self, model, features, prompt, negative_prompt,
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, steps, cfg_strength, seed, style_weight, gram_weight, steps, cfg_strength, seed,
normalize=True, target_lufs=-27.0): normalize=True, target_lufs=-27.0):
import traceback import traceback
@@ -244,7 +252,7 @@ class SelvaDittoOptimizer:
ref_mean, ref_gram, ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype, seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, steps, cfg_strength, seed, style_weight, gram_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar, normalize, target_lufs, pbar,
) )
except Exception as e: except Exception as e:
@@ -270,7 +278,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
ref_mean, ref_gram, ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype, seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, steps, cfg_strength, seed, style_weight, gram_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar): normalize, target_lufs, pbar):
"""Optimization loop — runs in a fresh thread (no inference_mode active).""" """Optimization loop — runs in a fresh thread (no inference_mode active)."""
@@ -411,7 +419,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2)) mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2))
# ── Style loss ─────────────────────────────────────────────────────── # ── Style loss ───────────────────────────────────────────────────────
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram) loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad loss.backward() # gradient flows through Phase 2 + STE back to x0.grad