fix: replace std clamp with anchor regularization to prevent OOD noise

The std clamp was post-hoc and only addressed magnitude, not direction.
x0 was drifting to mean=-0.55/std=3.1 (ODE expected mean=0/std=1).

Replace with anchor_weight * MSE(x0, x0_init) added directly to the loss.
The optimizer now balances style matching against staying near the initial
N(0,1) noise — gradient-aware, prevents both magnitude and mean drift.

Also logs style/anchor losses and x0_std per step for diagnostics.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 18:30:05 +02:00
parent fa6c4fa834
commit 445da1e69b
+21 -13
View File
@@ -152,6 +152,13 @@ class SelvaDittoOptimizer:
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
"0.1 adds texture matching but can introduce white noise.",
}),
"anchor_weight": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
"tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. "
"Prevents optimization from pushing x0 out of the flow's "
"expected distribution (which causes white noise). "
"Higher = cleaner audio, weaker style. 1.0 is a safe default.",
}),
"steps": ("INT", {
"default": 25, "min": 1, "max": 200,
"tooltip": "Euler steps for the final generation pass (after optimization).",
@@ -181,7 +188,7 @@ class SelvaDittoOptimizer:
def optimize(self, model, features, prompt, negative_prompt,
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, steps, cfg_strength, seed,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize=True, target_lufs=-27.0):
import traceback
@@ -286,7 +293,7 @@ class SelvaDittoOptimizer:
ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, steps, cfg_strength, seed,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar,
)
except Exception as e:
@@ -312,7 +319,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, steps, cfg_strength, seed,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar):
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
@@ -385,6 +392,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
device=device, dtype=dtype,
)
x0 = torch.nn.Parameter(x0_init.clone())
x0_init = x0_init.detach() # anchor — kept fixed, no grad
optimizer = torch.optim.Adam([x0], lr=opt_lr)
# n_grad_steps must not exceed n_ode_steps
@@ -449,25 +457,25 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
# clips in the same space. Avoids backprop through VAE decoder which
# is @torch.inference_mode() and produces noisy gradients.
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
# Anchor regularization — penalize x0 drifting from its initial N(0,1)
# value. Flow matching ODE expects x0 ~ N(0,1); large deviations push
# the ODE into an out-of-distribution region that decodes as white noise.
anchor_loss = anchor_weight * F.mse_loss(x0, x0_init)
loss = style_loss + anchor_loss
optimizer.zero_grad()
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
torch.nn.utils.clip_grad_norm_([x0], 1.0)
optimizer.step()
# Clamp x0 std to stay near unit Gaussian — flow matching ODE expects
# x0 ~ N(0,1). Optimization can push std >> 1, which maps to an
# out-of-distribution initial condition and produces white noise.
with torch.no_grad():
std = x0.std()
if std > 1.5:
x0.data.div_(std)
pbar.update(1)
if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
print(f"[DITTO] {opt_step+1}/{n_opt_steps} loss={loss.item():.4f}", flush=True)
print(f"[DITTO] {opt_step+1}/{n_opt_steps} "
f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} "
f"x0_std={x0.data.std().item():.3f}", flush=True)
# ── Final generation with optimized x_0 ─────────────────────────────────
print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)