feat: add gram_weight param to DITTO, reduce default style_weight to 0.1
White noise on output was caused by the Gram matrix loss pushing the latent into incoherent regions. Now gram_weight defaults to 0 (mean spectrum only) and style_weight defaults to 0.1 instead of 1.0. Users can enable Gram gradually once mean-only optimization converges cleanly. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -42,28 +42,30 @@ def _load_wav(path):
|
||||
return wav, sr
|
||||
|
||||
|
||||
def _mel_style_loss(mel_gen, ref_mean, ref_gram):
|
||||
def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
|
||||
"""Style loss between generated mel and precomputed reference statistics.
|
||||
|
||||
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
|
||||
ref_mean: [n_mels] mean spectrum of BJ reference clips (detached)
|
||||
ref_gram: [n_mels, n_mels] Gram matrix of BJ reference clips (detached)
|
||||
|
||||
Mean spectrum loss captures the spectral envelope (which harmonics are
|
||||
boosted). Gram matrix loss captures timbral texture — covariance between
|
||||
frequency bands — without requiring temporal alignment.
|
||||
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
|
||||
ref_mean: [n_mels] mean spectrum of reference clips (detached)
|
||||
ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached)
|
||||
gram_weight: weight for Gram matrix component — 0 = mean spectrum only.
|
||||
Start at 0; enable only if mean-only optimization converges
|
||||
without noise, then increase slowly (0.01–0.1).
|
||||
"""
|
||||
m = mel_gen.squeeze(0) # [n_mels, T]
|
||||
|
||||
# Mean spectrum loss
|
||||
# Mean spectrum loss — captures spectral envelope
|
||||
gen_mean = m.mean(dim=-1) # [n_mels]
|
||||
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
||||
|
||||
# Gram matrix loss (texture, position-invariant)
|
||||
if gram_weight <= 0.0:
|
||||
return loss_mean
|
||||
|
||||
# Gram matrix loss — captures timbral texture (can add noise if too high)
|
||||
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
|
||||
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
||||
|
||||
return loss_mean + 0.1 * loss_gram
|
||||
return loss_mean + gram_weight * loss_gram
|
||||
|
||||
|
||||
class SelvaDittoOptimizer:
|
||||
@@ -115,9 +117,15 @@ class SelvaDittoOptimizer:
|
||||
"Must be ≤ n_ode_steps. 5 is a good default.",
|
||||
}),
|
||||
"style_weight": ("FLOAT", {
|
||||
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
|
||||
"tooltip": "Weight of the BJ style loss. Increase to push harder toward "
|
||||
"BJ style at the cost of coherence with the video.",
|
||||
"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05,
|
||||
"tooltip": "Weight of the BJ style loss. High values push harder toward "
|
||||
"BJ style but add noise. Start at 0.1 and increase slowly.",
|
||||
}),
|
||||
"gram_weight": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,
|
||||
"tooltip": "Weight of the Gram matrix (timbral texture) loss relative to "
|
||||
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
|
||||
"0.1 adds texture matching but can introduce white noise.",
|
||||
}),
|
||||
"steps": ("INT", {
|
||||
"default": 25, "min": 1, "max": 200,
|
||||
@@ -148,7 +156,7 @@ class SelvaDittoOptimizer:
|
||||
|
||||
def optimize(self, model, features, prompt, negative_prompt,
|
||||
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||
style_weight, steps, cfg_strength, seed,
|
||||
style_weight, gram_weight, steps, cfg_strength, seed,
|
||||
normalize=True, target_lufs=-27.0):
|
||||
import traceback
|
||||
|
||||
@@ -244,7 +252,7 @@ class SelvaDittoOptimizer:
|
||||
ref_mean, ref_gram,
|
||||
seq_cfg, sample_rate, device, dtype,
|
||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||
style_weight, steps, cfg_strength, seed,
|
||||
style_weight, gram_weight, steps, cfg_strength, seed,
|
||||
normalize, target_lufs, pbar,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -270,7 +278,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
||||
ref_mean, ref_gram,
|
||||
seq_cfg, sample_rate, device, dtype,
|
||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||
style_weight, steps, cfg_strength, seed,
|
||||
style_weight, gram_weight, steps, cfg_strength, seed,
|
||||
normalize, target_lufs, pbar):
|
||||
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
|
||||
|
||||
@@ -411,7 +419,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
||||
mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2))
|
||||
|
||||
# ── Style loss ───────────────────────────────────────────────────────
|
||||
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
|
||||
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
||||
|
||||
Reference in New Issue
Block a user