diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index ded353a..1d5b462 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -42,28 +42,30 @@ def _load_wav(path): return wav, sr -def _mel_style_loss(mel_gen, ref_mean, ref_gram): +def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0): """Style loss between generated mel and precomputed reference statistics. - mel_gen: [1, n_mels, T] generated mel spectrogram (with grad) - ref_mean: [n_mels] mean spectrum of BJ reference clips (detached) - ref_gram: [n_mels, n_mels] Gram matrix of BJ reference clips (detached) - - Mean spectrum loss captures the spectral envelope (which harmonics are - boosted). Gram matrix loss captures timbral texture — covariance between - frequency bands — without requiring temporal alignment. + mel_gen: [1, n_mels, T] generated mel spectrogram (with grad) + ref_mean: [n_mels] mean spectrum of reference clips (detached) + ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached) + gram_weight: weight for Gram matrix component — 0 = mean spectrum only. + Start at 0; enable only if mean-only optimization converges + without noise, then increase slowly (0.01–0.1). """ m = mel_gen.squeeze(0) # [n_mels, T] - # Mean spectrum loss + # Mean spectrum loss — captures spectral envelope gen_mean = m.mean(dim=-1) # [n_mels] loss_mean = F.l1_loss(gen_mean, ref_mean) - # Gram matrix loss (texture, position-invariant) + if gram_weight <= 0.0: + return loss_mean + + # Gram matrix loss — captures timbral texture (can add noise if too high) gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels] loss_gram = F.mse_loss(gram_gen, ref_gram) - return loss_mean + 0.1 * loss_gram + return loss_mean + gram_weight * loss_gram class SelvaDittoOptimizer: @@ -115,9 +117,15 @@ class SelvaDittoOptimizer: "Must be ≤ n_ode_steps. 5 is a good default.", }), "style_weight": ("FLOAT", { - "default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, - "tooltip": "Weight of the BJ style loss. Increase to push harder toward " - "BJ style at the cost of coherence with the video.", + "default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05, + "tooltip": "Weight of the BJ style loss. High values push harder toward " + "BJ style but add noise. Start at 0.1 and increase slowly.", + }), + "gram_weight": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, + "tooltip": "Weight of the Gram matrix (timbral texture) loss relative to " + "the mean spectrum loss. 0 = mean spectrum only (less noise). " + "0.1 adds texture matching but can introduce white noise.", }), "steps": ("INT", { "default": 25, "min": 1, "max": 200, @@ -148,7 +156,7 @@ class SelvaDittoOptimizer: def optimize(self, model, features, prompt, negative_prompt, reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, - style_weight, steps, cfg_strength, seed, + style_weight, gram_weight, steps, cfg_strength, seed, normalize=True, target_lufs=-27.0): import traceback @@ -244,7 +252,7 @@ class SelvaDittoOptimizer: ref_mean, ref_gram, seq_cfg, sample_rate, device, dtype, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, - style_weight, steps, cfg_strength, seed, + style_weight, gram_weight, steps, cfg_strength, seed, normalize, target_lufs, pbar, ) except Exception as e: @@ -270,7 +278,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter, ref_mean, ref_gram, seq_cfg, sample_rate, device, dtype, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, - style_weight, steps, cfg_strength, seed, + style_weight, gram_weight, steps, cfg_strength, seed, normalize, target_lufs, pbar): """Optimization loop — runs in a fresh thread (no inference_mode active).""" @@ -411,7 +419,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter, mel_gen = feature_utils.tod.vae.decode(x_un.transpose(1, 2)) # ── Style loss ─────────────────────────────────────────────────────── - loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram) + loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight) optimizer.zero_grad() loss.backward() # gradient flows through Phase 2 + STE back to x0.grad