diff --git a/nodes.py b/nodes.py index 4cfd450..9b1532b 100644 --- a/nodes.py +++ b/nodes.py @@ -1,4 +1,5 @@ import torch +import torch.nn.functional as F class SMCCFGCtrl: @@ -101,6 +102,13 @@ class SMCCFGCtrl: phi = s.std().clamp(min=1e-6) u_sw = -K_eff * torch.tanh(s / phi) + # Spatial smoothing: the per-element correction creates a grid + # pattern at latent boundaries (each latent = 8x8 pixels). A mild + # 3x3 box blur in latent space removes these artifacts while + # preserving the large-scale correction structure. + if u_sw.ndim == 4: + u_sw = F.avg_pool2d(u_sw, kernel_size=3, stride=1, padding=1) + # Corrected guidance error (in normalized noise space) guidance_eps = guidance_eps + u_sw