Fix sigma-scaling bug causing noisy images

ComfyUI's args["cond"]/["uncond"] are (x - denoised), which are
sigma-scaled. At late denoising steps (sigma~0.01), the fixed K=0.2
correction was 200x the signal magnitude, destroying the image.

Fix: compute SMC in denoised space using args["cond_denoised"] and
args["uncond_denoised"], which have consistent magnitude across all
sigma values — matching the paper's noise-prediction space.
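
A rough numeric illustration of the scaling described above, as a toy sketch rather than ComfyUI code. It assumes x = x0 + sigma * eps and a denoiser that recovers x0 exactly; `residual_magnitude`, `gain_to_residual_ratio`, `eps` and `x0` are hypothetical names and values chosen for illustration. The factor seen in practice depends on the real prediction magnitudes, which this toy does not model.

```python
# Toy model: x = x0 + sigma * eps, with a denoiser that recovers x0 exactly.
# All numbers are illustrative assumptions, not measurements.
K = 0.2  # fixed switching gain quoted in the commit message


def residual_magnitude(sigma, eps=1.0, x0=0.5):
    """Return |x - denoised| for a perfect denoiser in the toy model."""
    x = x0 + sigma * eps
    denoised = x0  # assumption: perfect denoiser
    return abs(x - denoised)


def gain_to_residual_ratio(sigma):
    """Size of the fixed gain K relative to the sigma-scaled residual."""
    return K / residual_magnitude(sigma)


# The residual shrinks with sigma while K stays fixed, so the ratio grows.
for sigma in (10.0, 1.0, 0.1, 0.01):
    print(f"sigma={sigma:<5} residual={residual_magnitude(sigma):.4f} "
          f"K/residual={gain_to_residual_ratio(sigma):.2f}")
```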

Also fixes first-step behavior to match the original paper (SMC
correction applied from step 0, not step 1).
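
The first-step behavior can be sketched on scalars as follows. This is a minimal sketch, not the node's code: `smc_step` and `sign` are hypothetical helpers, and `lam=5.0` is an assumed value (the commit does not state lambda). Seeding the history with the current error makes the sliding surface reduce to lambda * e_0 on the first step, so the switching term is already active at step 0.

```python
def sign(v):
    """Scalar sign: -1, 0 or 1."""
    return (v > 0) - (v < 0)


def smc_step(e_t, prev_eps, lam, K):
    """One scalar SMC update; returns (corrected_error, new_prev_eps)."""
    if prev_eps is None:
        # First SMC step: seed the history with the current error,
        # so s = (e_t - e_t) + lam * e_t = lam * e_t.
        prev_eps = e_t
    s = (e_t - prev_eps) + lam * prev_eps  # sliding surface
    u_sw = -K * sign(s)                    # switching control
    corrected = e_t + u_sw
    # The corrected error is stored for the next step's sliding surface.
    return corrected, corrected


# Assumed values: lam=5.0 is hypothetical, K=0.2 is from the commit message.
corrected, prev = smc_step(0.5, None, lam=5.0, K=0.2)
print(corrected, prev)
```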

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-04 17:33:05 +01:00
parent 3953d97163
commit 612e7e973f


@@ -47,9 +47,14 @@ class SMCCFGCtrl:
         K = smc_cfg_K
         def smc_cfg_function(args):
-            cond = args["cond"]
-            uncond = args["uncond"]
+            # Use denoised-space predictions; these have consistent magnitude
+            # across sigma values. ComfyUI's args["cond"]/["uncond"] are
+            # (x - denoised), which are sigma-scaled and would make the fixed
+            # K correction dominate at low sigma (late steps), destroying the image.
+            cond_denoised = args["cond_denoised"]
+            uncond_denoised = args["uncond_denoised"]
             cond_scale = args["cond_scale"]
+            x = args["input"]
             sigma = args["sigma"]
             # Detect new generation: sigma should decrease monotonically during
@@ -65,28 +70,35 @@ class SMCCFGCtrl:
             # Warmup: pure conditional prediction (no guidance)
             if warmup_steps > 0 and step < warmup_steps:
-                return cond
+                return x - cond_denoised
-            # Guidance error: e_t = noise_cond - noise_uncond
-            guidance_eps = cond - uncond
-            if state["prev_eps"] is not None:
-                prev_eps = state["prev_eps"]
-                # Sliding surface: s_t = (e_t - e_{t-1}) + lambda * e_{t-1}
-                s = (guidance_eps - prev_eps) + lam * prev_eps
-                # Switching control: u_sw = -K * sign(s_t)
-                u_sw = -K * torch.sign(s)
-                # Apply correction to guidance error
-                guidance_eps = guidance_eps + u_sw
+            # Guidance error in denoised space (consistent magnitude across sigma)
+            guidance_eps = cond_denoised - uncond_denoised
+            # Initialize prev_eps on first SMC step (matches original paper
+            # where SMC correction is applied from the very first step)
+            if state["prev_eps"] is None:
+                state["prev_eps"] = guidance_eps.detach().clone()
+            prev_eps = state["prev_eps"]
+            # Sliding surface: s_t = (e_t - e_{t-1}) + lambda * e_{t-1}
+            s = (guidance_eps - prev_eps) + lam * prev_eps
+            # Switching control: u_sw = -K * sign(s_t)
+            u_sw = -K * torch.sign(s)
+            # Corrected guidance error
+            guidance_eps = guidance_eps + u_sw
             # Store corrected guidance for next step's sliding surface
             state["prev_eps"] = guidance_eps.detach().clone()
-            # v_guided = v_uncond + scale * corrected_guidance
-            return uncond + cond_scale * guidance_eps
+            # Guided denoised output
+            denoised = uncond_denoised + cond_scale * guidance_eps
+            # Return noise residual (framework computes cfg_result = x - return)
+            return x - denoised
         m = model.clone()
         m.set_model_sampler_cfg_function(smc_cfg_function, disable_cfg1_optimization=True)
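
The return convention used by the new code can be checked on scalars. This is a minimal sketch under the assumption stated in the diff comment, that the framework computes cfg_result = x - return; `guided_denoised`, `cfg_function_return` and `framework_result` are hypothetical helpers, not ComfyUI functions, and the SMC correction is left out to isolate the convention.

```python
def guided_denoised(cond_denoised, uncond_denoised, cond_scale):
    """Plain CFG combination in denoised space (no SMC correction)."""
    return uncond_denoised + cond_scale * (cond_denoised - uncond_denoised)


def cfg_function_return(x, denoised):
    """What the cfg function hands back: the residual x - denoised."""
    return x - denoised


def framework_result(x, returned):
    """Assumed framework step, per the diff comment: cfg_result = x - return."""
    return x - returned


# Illustrative values only.
x, cond_d, uncond_d, scale = 2.0, 0.5, 0.25, 4.0

denoised = guided_denoised(cond_d, uncond_d, scale)
# Round trip: the framework recovers the guided denoised prediction.
print(framework_result(x, cfg_function_return(x, denoised)) == denoised)

# Warmup branch: returning x - cond_denoised yields the pure conditional prediction.
print(framework_result(x, x - cond_d) == cond_d)
```

Returning the residual rather than the denoised value keeps the function compatible with that subtraction, which is why both the warmup branch and the final return in the diff are written as x minus a denoised-space quantity.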