diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index 247f01a..b529871 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -152,6 +152,13 @@ class SelvaDittoOptimizer: "the mean spectrum loss. 0 = mean spectrum only (less noise). " "0.1 adds texture matching but can introduce white noise.", }), + "anchor_weight": ("FLOAT", { + "default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, + "tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. " + "Prevents optimization from pushing x0 out of the flow's " + "expected distribution (which causes white noise). " + "Higher = cleaner audio, weaker style. 1.0 is a safe default.", + }), "steps": ("INT", { "default": 25, "min": 1, "max": 200, "tooltip": "Euler steps for the final generation pass (after optimization).", @@ -181,7 +188,7 @@ class SelvaDittoOptimizer: def optimize(self, model, features, prompt, negative_prompt, reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, - style_weight, gram_weight, steps, cfg_strength, seed, + style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed, normalize=True, target_lufs=-27.0): import traceback @@ -286,7 +293,7 @@ class SelvaDittoOptimizer: ref_mean, ref_gram, seq_cfg, sample_rate, device, dtype, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, - style_weight, gram_weight, steps, cfg_strength, seed, + style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed, normalize, target_lufs, pbar, ) except Exception as e: @@ -312,7 +319,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter, ref_mean, ref_gram, seq_cfg, sample_rate, device, dtype, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, - style_weight, gram_weight, steps, cfg_strength, seed, + style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed, normalize, target_lufs, pbar): """Optimization loop — runs in a fresh thread (no inference_mode active).""" @@ -384,7 +391,8 @@ def _do_optimize(net_generator, feature_utils, mel_converter, 1, seq_cfg.latent_seq_len, net_generator.latent_dim, device=device, dtype=dtype, ) - x0 = torch.nn.Parameter(x0_init.clone()) + x0 = torch.nn.Parameter(x0_init.clone()) + x0_init = x0_init.detach() # anchor — kept fixed, no grad optimizer = torch.optim.Adam([x0], lr=opt_lr) # n_grad_steps must not exceed n_ode_steps @@ -449,25 +457,25 @@ def _do_optimize(net_generator, feature_utils, mel_converter, # clips in the same space. Avoids backprop through VAE decoder which # is @torch.inference_mode() and produces noisy gradients. x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat] - loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight) + style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight) + + # Anchor regularization — penalize x0 drifting from its initial N(0,1) + # value. Flow matching ODE expects x0 ~ N(0,1); large deviations push + # the ODE into an out-of-distribution region that decodes as white noise. + anchor_loss = anchor_weight * F.mse_loss(x0, x0_init) + loss = style_loss + anchor_loss optimizer.zero_grad() loss.backward() # gradient flows through Phase 2 + STE back to x0.grad torch.nn.utils.clip_grad_norm_([x0], 1.0) optimizer.step() - # Clamp x0 std to stay near unit Gaussian — flow matching ODE expects - # x0 ~ N(0,1). Optimization can push std >> 1, which maps to an - # out-of-distribution initial condition and produces white noise. - with torch.no_grad(): - std = x0.std() - if std > 1.5: - x0.data.div_(std) - pbar.update(1) if (opt_step + 1) % max(1, n_opt_steps // 10) == 0: - print(f"[DITTO] {opt_step+1}/{n_opt_steps} loss={loss.item():.4f}", flush=True) + print(f"[DITTO] {opt_step+1}/{n_opt_steps} " + f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} " + f"x0_std={x0.data.std().item():.3f}", flush=True) # ── Final generation with optimized x_0 ───────────────────────────────── print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)