fix: replace std clamp with anchor regularization to prevent OOD noise
The std clamp was post-hoc and only addressed magnitude, not direction: x0 was drifting to mean=-0.55/std=3.1, whereas the flow-matching ODE expects mean=0/std=1.

Replace it with anchor_weight * MSE(x0, x0_init) added directly to the loss. The optimizer now balances style matching against staying near the initial N(0,1) noise — the penalty is gradient-aware and prevents both magnitude and mean drift.

Also logs the style/anchor losses and x0_std at each progress-log interval (roughly every n_opt_steps // 10 optimization steps) for diagnostics.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -152,6 +152,13 @@ class SelvaDittoOptimizer:
|
|||||||
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
|
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
|
||||||
"0.1 adds texture matching but can introduce white noise.",
|
"0.1 adds texture matching but can introduce white noise.",
|
||||||
}),
|
}),
|
||||||
|
"anchor_weight": ("FLOAT", {
|
||||||
|
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
|
||||||
|
"tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. "
|
||||||
|
"Prevents optimization from pushing x0 out of the flow's "
|
||||||
|
"expected distribution (which causes white noise). "
|
||||||
|
"Higher = cleaner audio, weaker style. 1.0 is a safe default.",
|
||||||
|
}),
|
||||||
"steps": ("INT", {
|
"steps": ("INT", {
|
||||||
"default": 25, "min": 1, "max": 200,
|
"default": 25, "min": 1, "max": 200,
|
||||||
"tooltip": "Euler steps for the final generation pass (after optimization).",
|
"tooltip": "Euler steps for the final generation pass (after optimization).",
|
||||||
@@ -181,7 +188,7 @@ class SelvaDittoOptimizer:
|
|||||||
|
|
||||||
def optimize(self, model, features, prompt, negative_prompt,
|
def optimize(self, model, features, prompt, negative_prompt,
|
||||||
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||||
style_weight, gram_weight, steps, cfg_strength, seed,
|
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
||||||
normalize=True, target_lufs=-27.0):
|
normalize=True, target_lufs=-27.0):
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
@@ -286,7 +293,7 @@ class SelvaDittoOptimizer:
|
|||||||
ref_mean, ref_gram,
|
ref_mean, ref_gram,
|
||||||
seq_cfg, sample_rate, device, dtype,
|
seq_cfg, sample_rate, device, dtype,
|
||||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||||
style_weight, gram_weight, steps, cfg_strength, seed,
|
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
||||||
normalize, target_lufs, pbar,
|
normalize, target_lufs, pbar,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -312,7 +319,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
ref_mean, ref_gram,
|
ref_mean, ref_gram,
|
||||||
seq_cfg, sample_rate, device, dtype,
|
seq_cfg, sample_rate, device, dtype,
|
||||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
||||||
style_weight, gram_weight, steps, cfg_strength, seed,
|
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
||||||
normalize, target_lufs, pbar):
|
normalize, target_lufs, pbar):
|
||||||
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
|
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
|
||||||
|
|
||||||
@@ -384,7 +391,8 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
||||||
device=device, dtype=dtype,
|
device=device, dtype=dtype,
|
||||||
)
|
)
|
||||||
x0 = torch.nn.Parameter(x0_init.clone())
|
x0 = torch.nn.Parameter(x0_init.clone())
|
||||||
|
x0_init = x0_init.detach() # anchor — kept fixed, no grad
|
||||||
optimizer = torch.optim.Adam([x0], lr=opt_lr)
|
optimizer = torch.optim.Adam([x0], lr=opt_lr)
|
||||||
|
|
||||||
# n_grad_steps must not exceed n_ode_steps
|
# n_grad_steps must not exceed n_ode_steps
|
||||||
@@ -449,25 +457,25 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
# clips in the same space. Avoids backprop through VAE decoder which
|
# clips in the same space. Avoids backprop through VAE decoder which
|
||||||
# is @torch.inference_mode() and produces noisy gradients.
|
# is @torch.inference_mode() and produces noisy gradients.
|
||||||
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
|
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
|
||||||
loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
|
style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
|
||||||
|
|
||||||
|
# Anchor regularization — penalize x0 drifting from its initial N(0,1)
|
||||||
|
# value. Flow matching ODE expects x0 ~ N(0,1); large deviations push
|
||||||
|
# the ODE into an out-of-distribution region that decodes as white noise.
|
||||||
|
anchor_loss = anchor_weight * F.mse_loss(x0, x0_init)
|
||||||
|
loss = style_loss + anchor_loss
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
||||||
torch.nn.utils.clip_grad_norm_([x0], 1.0)
|
torch.nn.utils.clip_grad_norm_([x0], 1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# Clamp x0 std to stay near unit Gaussian — flow matching ODE expects
|
|
||||||
# x0 ~ N(0,1). Optimization can push std >> 1, which maps to an
|
|
||||||
# out-of-distribution initial condition and produces white noise.
|
|
||||||
with torch.no_grad():
|
|
||||||
std = x0.std()
|
|
||||||
if std > 1.5:
|
|
||||||
x0.data.div_(std)
|
|
||||||
|
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
|
if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
|
||||||
print(f"[DITTO] {opt_step+1}/{n_opt_steps} loss={loss.item():.4f}", flush=True)
|
print(f"[DITTO] {opt_step+1}/{n_opt_steps} "
|
||||||
|
f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} "
|
||||||
|
f"x0_std={x0.data.std().item():.3f}", flush=True)
|
||||||
|
|
||||||
# ── Final generation with optimized x_0 ─────────────────────────────────
|
# ── Final generation with optimized x_0 ─────────────────────────────────
|
||||||
print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)
|
print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user