Use sqrt(cfg) scaling for K + store raw prev_eps
Using the full K corrupts the image at high CFG: the output correction is cfg*K, which reaches 2.4 at cfg=12 (with K=0.2). Dividing K by cfg was too weak: the output correction is then only 0.2 at cfg=12. The paper only tested up to cfg=7.5, where output corrections range from 0.5 to 1.5. Scaling K by 1/sqrt(cfg) makes the output correction equal to sqrt(cfg)*K, which grows sub-linearly with cfg and gives 0.69 at cfg=12 — within the paper's working range. Also store the raw (pre-correction) guidance as prev_eps, so that corrections do not accumulate through the sliding surface. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
56
nodes.py
56
nodes.py
@@ -37,24 +37,17 @@ class SMCCFGCtrl:
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
def patch(self, model, smc_cfg_lambda, smc_cfg_K, warmup_steps):
|
||||
# Mutable state persisted across denoising steps via closure
|
||||
state = {
|
||||
"prev_eps": None,
|
||||
"step": 0,
|
||||
"prev_sigma": None,
|
||||
}
|
||||
|
||||
state = {"prev_eps": None, "step": 0, "prev_sigma": None}
|
||||
lam = smc_cfg_lambda
|
||||
K = smc_cfg_K
|
||||
|
||||
def smc_cfg_function(args):
|
||||
cond = args["cond"] # x - cond_denoised (sigma-scaled noise)
|
||||
uncond = args["uncond"] # x - uncond_denoised (sigma-scaled noise)
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
sigma = args["sigma"]
|
||||
|
||||
# Detect new generation: sigma should decrease monotonically during
|
||||
# denoising. If it jumps up, a new sampling run has started.
|
||||
# Detect new generation (sigma jumps up = new sampling run)
|
||||
curr_sigma = sigma.max().item() if torch.is_tensor(sigma) else float(sigma)
|
||||
if state["prev_sigma"] is not None and curr_sigma > state["prev_sigma"] * 1.1:
|
||||
state["prev_eps"] = None
|
||||
@@ -64,21 +57,13 @@ class SMCCFGCtrl:
|
||||
step = state["step"]
|
||||
state["step"] = step + 1
|
||||
|
||||
# Warmup: pure conditional prediction (no guidance)
|
||||
if warmup_steps > 0 and step < warmup_steps:
|
||||
return cond
|
||||
|
||||
# Normalize to noise-prediction space by dividing out sigma.
|
||||
# The paper's K is calibrated for unit-variance noise predictions.
|
||||
# ComfyUI's cond/uncond are (x - denoised) ≈ sigma * epsilon,
|
||||
# so dividing by sigma recovers epsilon-space where K=0.2 is correct.
|
||||
# Crucially, when converting back, the sigma factor naturally dampens
|
||||
# the correction at late steps (small sigma), preventing noise injection.
|
||||
# Normalize to noise-prediction space (divide out sigma).
|
||||
sigma_val = max(curr_sigma, 1e-8)
|
||||
guidance_eps = (cond - uncond) / sigma_val
|
||||
|
||||
# Initialize prev_eps on first SMC step (matches original paper
|
||||
# where SMC correction is applied from the very first step)
|
||||
if state["prev_eps"] is None:
|
||||
state["prev_eps"] = guidance_eps.detach().clone()
|
||||
|
||||
@@ -87,29 +72,26 @@ class SMCCFGCtrl:
|
||||
# Sliding surface: s_t = (e_t - e_{t-1}) + lambda * e_{t-1}
|
||||
s = (guidance_eps - prev_eps) + lam * prev_eps
|
||||
|
||||
# Smooth switching via tanh(s/phi) instead of hard sign(s).
|
||||
# The paper uses sign(s) which works in DiffSynth but creates
|
||||
# salt-and-pepper artifacts in ComfyUI's latent space. tanh
|
||||
# provides smooth spatial gradients: proportional near zero,
|
||||
# saturating at ±K for large |s|.
|
||||
phi = s.std().clamp(min=1e-6)
|
||||
u_sw = -K * torch.tanh(s / phi)
|
||||
# Scale K so the output correction (cond_scale * K_eff) stays
|
||||
# in the range the paper tested (0.5–1.5). The paper only
|
||||
# tested up to cfg=7.5; at cfg=12, full K gives cond_scale*K=2.4
|
||||
# which corrupts the image. sqrt scaling keeps the output
|
||||
# correction growing sub-linearly with cfg.
|
||||
K_eff = K / max(cond_scale, 1.0) ** 0.5
|
||||
|
||||
# Spatial smoothing: blur the correction to remove per-element
|
||||
# grid artifacts at VAE patch boundaries (each latent = 8x8 px).
|
||||
# Smooth switching via tanh instead of sign to avoid
|
||||
# salt-and-pepper artifacts in ComfyUI's latent space.
|
||||
phi = s.std().clamp(min=1e-6)
|
||||
u_sw = -K_eff * torch.tanh(s / phi)
|
||||
|
||||
# Spatial blur to smooth grid artifacts at VAE patch boundaries.
|
||||
if u_sw.ndim == 4:
|
||||
u_sw = F.avg_pool2d(u_sw, kernel_size=5, stride=1, padding=2)
|
||||
|
||||
# Store RAW guidance (before correction) for the next step's
|
||||
# sliding surface. The paper stores corrected guidance, but in
|
||||
# ComfyUI the corrections accumulate through the surface's
|
||||
# lambda * prev_eps term (amplified 4x per step at lambda=5),
|
||||
# overwhelming the actual guidance signal after a few steps.
|
||||
# Storing raw guidance keeps the surface tracking the model's
|
||||
# actual guidance evolution while applying corrections fresh.
|
||||
# Store RAW guidance (before correction) to prevent correction
|
||||
# accumulation through the lambda * prev_eps term.
|
||||
state["prev_eps"] = guidance_eps.detach().clone()
|
||||
|
||||
# Apply correction and convert back to sigma-scaled space
|
||||
return uncond + cond_scale * (guidance_eps + u_sw) * sigma_val
|
||||
|
||||
m = model.clone()
|
||||
|
||||
Reference in New Issue
Block a user