feat: add gram_weight param to DITTO, reduce default style_weight to 0.1
White noise on output was caused by the Gram matrix loss pushing the latent into incoherent regions. Now gram_weight defaults to 0 (mean spectrum only) and style_weight defaults to 0.1 instead of 1.0. Users can enable Gram gradually once mean-only optimization converges cleanly. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -42,28 +42,30 @@ def _load_wav(path):
|
|||||||
return wav, sr
|
return wav, sr
|
||||||
|
|
||||||
|
|
||||||
def _mel_style_loss(mel_gen, ref_mean, ref_gram):
|
def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
|
||||||
"""Style loss between generated mel and precomputed reference statistics.
|
"""Style loss between generated mel and precomputed reference statistics.
|
||||||
|
|
||||||
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
|
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
|
||||||
ref_mean: [n_mels] mean spectrum of BJ reference clips (detached)
|
ref_mean: [n_mels] mean spectrum of reference clips (detached)
|
||||||
ref_gram: [n_mels, n_mels] Gram matrix of BJ reference clips (detached)
|
ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached)
|
||||||
|
gram_weight: weight for Gram matrix component — 0 = mean spectrum only.
|
||||||
Mean spectrum loss captures the spectral envelope (which harmonics are
|
Start at 0; enable only if mean-only optimization converges
|
||||||
boosted). Gram matrix loss captures timbral texture — covariance between
|
without noise, then increase slowly (0.01–0.1).
|
||||||
frequency bands — without requiring temporal alignment.
|
|
||||||
"""
|
"""
|
||||||
m = mel_gen.squeeze(0) # [n_mels, T]
|
m = mel_gen.squeeze(0) # [n_mels, T]
|
||||||
|
|
||||||
# Mean spectrum loss
|
# Mean spectrum loss — captures spectral envelope
|
||||||
gen_mean = m.mean(dim=-1) # [n_mels]
|
gen_mean = m.mean(dim=-1) # [n_mels]
|
||||||
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
||||||
|
|
||||||
# Gram matrix loss (texture, position-invariant)
|
if gram_weight <= 0.0:
|
||||||
|
return loss_mean
|
||||||
|
|
||||||
|
# Gram matrix loss — captures timbral texture (can add noise if too high)
|
||||||
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
|
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
|
||||||
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
||||||
|
|
||||||
return loss_mean + 0.1 * loss_gram
|
return loss_mean + gram_weight * loss_gram
|
||||||
|
|
||||||
|
|
||||||
class SelvaDittoOptimizer:
|
class SelvaDittoOptimizer:
|
||||||
@@ -115,9 +117,15 @@ class SelvaDittoOptimizer:
|
|||||||
"Must be ≤ n_ode_steps. 5 is a good default.",
|
"Must be ≤ n_ode_steps. 5 is a good default.",
|
||||||
}),
|
}),
|
||||||
"style_weight": ("FLOAT", {
|
"style_weight": ("FLOAT", {
|
||||||
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
|
"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05,
|
||||||
"tooltip": "Weight of the BJ style loss. Increase to push harder toward "
|
"tooltip": "Weight of the BJ style loss. High values push harder toward "
|
||||||
"BJ style at the cost of coherence with the video.",
|
"BJ style but add noise. Start at 0.1 and increase slowly.",
|
||||||
|
}),
|
||||||
|
"gram_weight": ("FLOAT", {
|
||||||
|
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,
|
||||||
|
"tooltip": "Weight of the Gram matrix (timbral texture) loss relative to "
|
||||||
|
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
|
||||||
|
"0.1 adds texture matching but can introduce white noise.",
|
||||||
}),
|
}),
|
||||||
"steps": ("INT", {
|
"steps": ("INT", {
|
||||||
"default": 25, "min": 1, "max": 200,
|
"default": 25, "min": 1, "max": 200,
|
||||||
@@ -148,7 +156,7 @@ class SelvaDittoOptimizer:
|
|||||||
|
|
||||||
def optimize(self, model, features, prompt, negative_prompt,
|
def optimize(self, model, features, prompt, negative_prompt,
|
||||||
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||||
style_weight, steps, cfg_strength, seed,
|
style_weight, gram_weight, steps, cfg_strength, seed,
|
||||||
normalize=True, target_lufs=-27.0):
|
normalize=True, target_lufs=-27.0):
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
@@ -244,7 +252,7 @@ class SelvaDittoOptimizer:
|
|||||||
ref_mean, ref_gram,
|
ref_mean, ref_gram,
|
||||||
seq_cfg, sample_rate, device, dtype,
|
seq_cfg, sample_rate, device, dtype,
|
||||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||||
style_weight, steps, cfg_strength, seed,
|
style_weight, gram_weight, steps, cfg_strength, seed,
|
||||||
normalize, target_lufs, pbar,
|
normalize, target_lufs, pbar,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -270,7 +278,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
ref_mean, ref_gram,
|
ref_mean, ref_gram,
|
||||||
seq_cfg, sample_rate, device, dtype,
|
seq_cfg, sample_rate, device, dtype,
|
||||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||||
style_weight, steps, cfg_strength, seed,
|
style_weight, gram_weight, steps, cfg_strength, seed,
|
||||||
normalize, target_lufs, pbar):
|
normalize, target_lufs, pbar):
|
||||||
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
|
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
|
||||||
|
|
||||||
@@ -411,7 +419,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2))
|
mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2))
|
||||||
|
|
||||||
# ── Style loss ───────────────────────────────────────────────────────
|
# ── Style loss ───────────────────────────────────────────────────────
|
||||||
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
|
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
||||||
|
|||||||
Reference in New Issue
Block a user