fix: replace std clamp with anchor regularization to prevent OOD noise
The std clamp was post-hoc and only addressed magnitude, not direction. x0 was drifting to mean=-0.55/std=3.1 (ODE expected mean=0/std=1). Replace with anchor_weight * MSE(x0, x0_init) added directly to the loss. The optimizer now balances style matching against staying near the initial N(0,1) noise — gradient-aware, prevents both magnitude and mean drift. Also logs style/anchor losses and x0_std per step for diagnostics. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -152,6 +152,13 @@ class SelvaDittoOptimizer:
|
||||
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
|
||||
"0.1 adds texture matching but can introduce white noise.",
|
||||
}),
|
||||
"anchor_weight": ("FLOAT", {
|
||||
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
|
||||
"tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. "
|
||||
"Prevents optimization from pushing x0 out of the flow's "
|
||||
"expected distribution (which causes white noise). "
|
||||
"Higher = cleaner audio, weaker style. 1.0 is a safe default.",
|
||||
}),
|
||||
"steps": ("INT", {
|
||||
"default": 25, "min": 1, "max": 200,
|
||||
"tooltip": "Euler steps for the final generation pass (after optimization).",
|
||||
@@ -181,7 +188,7 @@ class SelvaDittoOptimizer:
|
||||
|
||||
def optimize(self, model, features, prompt, negative_prompt,
|
||||
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||
style_weight, gram_weight, steps, cfg_strength, seed,
|
||||
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
||||
normalize=True, target_lufs=-27.0):
|
||||
import traceback
|
||||
|
||||
@@ -286,7 +293,7 @@ class SelvaDittoOptimizer:
|
||||
ref_mean, ref_gram,
|
||||
seq_cfg, sample_rate, device, dtype,
|
||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||
style_weight, gram_weight, steps, cfg_strength, seed,
|
||||
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
||||
normalize, target_lufs, pbar,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -312,7 +319,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
||||
ref_mean, ref_gram,
|
||||
seq_cfg, sample_rate, device, dtype,
|
||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||
style_weight, gram_weight, steps, cfg_strength, seed,
|
||||
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
||||
normalize, target_lufs, pbar):
|
||||
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
|
||||
|
||||
@@ -385,6 +392,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
x0 = torch.nn.Parameter(x0_init.clone())
|
||||
x0_init = x0_init.detach() # anchor — kept fixed, no grad
|
||||
optimizer = torch.optim.Adam([x0], lr=opt_lr)
|
||||
|
||||
# n_grad_steps must not exceed n_ode_steps
|
||||
@@ -449,25 +457,25 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
||||
# clips in the same space. Avoids backprop through VAE decoder which
|
||||
# is @torch.inference_mode() and produces noisy gradients.
|
||||
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
|
||||
loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
|
||||
style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
|
||||
|
||||
# Anchor regularization — penalize x0 drifting from its initial N(0,1)
|
||||
# value. Flow matching ODE expects x0 ~ N(0,1); large deviations push
|
||||
# the ODE into an out-of-distribution region that decodes as white noise.
|
||||
anchor_loss = anchor_weight * F.mse_loss(x0, x0_init)
|
||||
loss = style_loss + anchor_loss
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
||||
torch.nn.utils.clip_grad_norm_([x0], 1.0)
|
||||
optimizer.step()
|
||||
|
||||
# Clamp x0 std to stay near unit Gaussian — flow matching ODE expects
|
||||
# x0 ~ N(0,1). Optimization can push std >> 1, which maps to an
|
||||
# out-of-distribution initial condition and produces white noise.
|
||||
with torch.no_grad():
|
||||
std = x0.std()
|
||||
if std > 1.5:
|
||||
x0.data.div_(std)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
|
||||
print(f"[DITTO] {opt_step+1}/{n_opt_steps} loss={loss.item():.4f}", flush=True)
|
||||
print(f"[DITTO] {opt_step+1}/{n_opt_steps} "
|
||||
f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} "
|
||||
f"x0_std={x0.data.std().item():.3f}", flush=True)
|
||||
|
||||
# ── Final generation with optimized x_0 ─────────────────────────────────
|
||||
print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)
|
||||
|
||||
Reference in New Issue
Block a user