diff --git a/nodes.py b/nodes.py index 7f52c64..bbe62e6 100644 --- a/nodes.py +++ b/nodes.py @@ -47,9 +47,14 @@ class SMCCFGCtrl: K = smc_cfg_K def smc_cfg_function(args): - cond = args["cond"] - uncond = args["uncond"] + # Use denoised-space predictions — these have consistent magnitude + # across sigma values. ComfyUI's args["cond"]/["uncond"] are + # (x - denoised), which are sigma-scaled and would make the fixed + # K correction dominate at low sigma (late steps), destroying the image. + cond_denoised = args["cond_denoised"] + uncond_denoised = args["uncond_denoised"] cond_scale = args["cond_scale"] + x = args["input"] sigma = args["sigma"] # Detect new generation: sigma should decrease monotonically during @@ -65,28 +70,35 @@ class SMCCFGCtrl: # Warmup: pure conditional prediction (no guidance) if warmup_steps > 0 and step < warmup_steps: - return cond + return x - cond_denoised - # Guidance error: e_t = noise_cond - noise_uncond - guidance_eps = cond - uncond + # Guidance error in denoised space (consistent magnitude across sigma) + guidance_eps = cond_denoised - uncond_denoised - if state["prev_eps"] is not None: - prev_eps = state["prev_eps"] + # Initialize prev_eps on first SMC step (matches original paper + # where SMC correction is applied from the very first step) + if state["prev_eps"] is None: + state["prev_eps"] = guidance_eps.detach().clone() - # Sliding surface: s_t = (e_t - e_{t-1}) + lambda * e_{t-1} - s = (guidance_eps - prev_eps) + lam * prev_eps + prev_eps = state["prev_eps"] - # Switching control: u_sw = -K * sign(s_t) - u_sw = -K * torch.sign(s) + # Sliding surface: s_t = (e_t - e_{t-1}) + lambda * e_{t-1} + s = (guidance_eps - prev_eps) + lam * prev_eps - # Apply correction to guidance error - guidance_eps = guidance_eps + u_sw + # Switching control: u_sw = -K * sign(s_t) + u_sw = -K * torch.sign(s) + + # Corrected guidance error + guidance_eps = guidance_eps + u_sw # Store corrected guidance for next step's sliding surface state["prev_eps"] = guidance_eps.detach().clone() - # v_guided = v_uncond + scale * corrected_guidance - return uncond + cond_scale * guidance_eps + # Guided denoised output + denoised = uncond_denoised + cond_scale * guidance_eps + + # Return noise residual (framework computes cfg_result = x - return) + return x - denoised m = model.clone() m.set_model_sampler_cfg_function(smc_cfg_function, disable_cfg1_optimization=True)