From d29808975f2bd3830d7853a0f2d3143844b6861f Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Wed, 4 Mar 2026 23:39:16 +0100 Subject: [PATCH] Use sqrt(cfg) scaling for K + store raw prev_eps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Full K corrupts at high CFG (output correction = cfg*K = 2.4 at cfg=12). K/cfg was too weak (0.2 at cfg=12). The paper only tested up to cfg=7.5 where output corrections range 0.5-1.5. K/sqrt(cfg) keeps output correction = sqrt(cfg)*K growing sub-linearly, giving 0.69 at cfg=12 — within the paper's working range. Also store raw (pre-correction) guidance as prev_eps to prevent correction accumulation through the sliding surface. Co-Authored-By: Claude Opus 4.6 --- nodes.py | 56 +++++++++++++++++++------------------------------------- 1 file changed, 19 insertions(+), 37 deletions(-) diff --git a/nodes.py b/nodes.py index 2abd2c6..6ff566f 100644 --- a/nodes.py +++ b/nodes.py @@ -37,24 +37,17 @@ class SMCCFGCtrl: CATEGORY = "sampling/custom_sampling" def patch(self, model, smc_cfg_lambda, smc_cfg_K, warmup_steps): - # Mutable state persisted across denoising steps via closure - state = { - "prev_eps": None, - "step": 0, - "prev_sigma": None, - } - + state = {"prev_eps": None, "step": 0, "prev_sigma": None} lam = smc_cfg_lambda K = smc_cfg_K def smc_cfg_function(args): - cond = args["cond"] # x - cond_denoised (sigma-scaled noise) - uncond = args["uncond"] # x - uncond_denoised (sigma-scaled noise) + cond = args["cond"] + uncond = args["uncond"] cond_scale = args["cond_scale"] sigma = args["sigma"] - # Detect new generation: sigma should decrease monotonically during - # denoising. If it jumps up, a new sampling run has started. + # Detect new generation (sigma jumps up = new sampling run) curr_sigma = sigma.max().item() if torch.is_tensor(sigma) else float(sigma) if state["prev_sigma"] is not None and curr_sigma > state["prev_sigma"] * 1.1: state["prev_eps"] = None @@ -64,21 +57,13 @@ class SMCCFGCtrl: step = state["step"] state["step"] = step + 1 - # Warmup: pure conditional prediction (no guidance) if warmup_steps > 0 and step < warmup_steps: return cond - # Normalize to noise-prediction space by dividing out sigma. - # The paper's K is calibrated for unit-variance noise predictions. - # ComfyUI's cond/uncond are (x - denoised) ≈ sigma * epsilon, - # so dividing by sigma recovers epsilon-space where K=0.2 is correct. - # Crucially, when converting back, the sigma factor naturally dampens - # the correction at late steps (small sigma), preventing noise injection. + # Normalize to noise-prediction space (divide out sigma). sigma_val = max(curr_sigma, 1e-8) guidance_eps = (cond - uncond) / sigma_val - # Initialize prev_eps on first SMC step (matches original paper - # where SMC correction is applied from the very first step) if state["prev_eps"] is None: state["prev_eps"] = guidance_eps.detach().clone() @@ -87,29 +72,26 @@ class SMCCFGCtrl: # Sliding surface: s_t = (e_t - e_{t-1}) + lambda * e_{t-1} s = (guidance_eps - prev_eps) + lam * prev_eps - # Smooth switching via tanh(s/phi) instead of hard sign(s). - # The paper uses sign(s) which works in DiffSynth but creates - # salt-and-pepper artifacts in ComfyUI's latent space. tanh - # provides smooth spatial gradients: proportional near zero, - # saturating at ±K for large |s|. - phi = s.std().clamp(min=1e-6) - u_sw = -K * torch.tanh(s / phi) + # Scale K so the output correction (cond_scale * K_eff) stays + # in the range the paper tested (0.5–1.5). The paper only + # tested up to cfg=7.5; at cfg=12, full K gives cond_scale*K=2.4 + # which corrupts the image. sqrt scaling keeps the output + # correction growing sub-linearly with cfg. + K_eff = K / max(cond_scale, 1.0) ** 0.5 - # Spatial smoothing: blur the correction to remove per-element - # grid artifacts at VAE patch boundaries (each latent = 8x8 px). + # Smooth switching via tanh instead of sign to avoid + # salt-and-pepper artifacts in ComfyUI's latent space. + phi = s.std().clamp(min=1e-6) + u_sw = -K_eff * torch.tanh(s / phi) + + # Spatial blur to smooth grid artifacts at VAE patch boundaries. if u_sw.ndim == 4: u_sw = F.avg_pool2d(u_sw, kernel_size=5, stride=1, padding=2) - # Store RAW guidance (before correction) for the next step's - # sliding surface. The paper stores corrected guidance, but in - # ComfyUI the corrections accumulate through the surface's - # lambda * prev_eps term (amplified 4x per step at lambda=5), - # overwhelming the actual guidance signal after a few steps. - # Storing raw guidance keeps the surface tracking the model's - # actual guidance evolution while applying corrections fresh. + # Store RAW guidance (before correction) to prevent correction + # accumulation through the lambda * prev_eps term. state["prev_eps"] = guidance_eps.detach().clone() - # Apply correction and convert back to sigma-scaled space return uncond + cond_scale * (guidance_eps + u_sw) * sigma_val m = model.clone()