From 8a440761d1c5105214233496701757f3bc8ef9d6 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 15 Feb 2026 02:03:34 +0100 Subject: [PATCH] Fix noise level (900 not 1000) and prompt concatenation to match original STAR The original STAR inference uses total_noise_levels=900, preserving input structure during SDEdit. We had 1000 which starts from near-pure noise, destroying the input. Also always append the quality prompt to user text instead of using it only as a fallback. Co-Authored-By: Claude Opus 4.6 --- inference.py | 19 ++++++------------- nodes.py | 19 ++++++------------- star_pipeline.py | 4 ++-- 3 files changed, 14 insertions(+), 28 deletions(-) diff --git a/inference.py b/inference.py index 1950141..4095e32 100755 --- a/inference.py +++ b/inference.py @@ -133,26 +133,19 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None): if _active_attn == "xformers": return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op) if _active_attn == "math": - # Naive batched attention with fp32 accumulation (matches xformers). - orig_dtype = q.dtype + # Naive batched attention — slow but guaranteed correct. scale = q.shape[-1] ** -0.5 - cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * 4)) + cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1))) outs = [] for i in range(0, q.shape[0], cs): - qi, ki, vi = q[i:i+cs].float(), k[i:i+cs].float(), v[i:i+cs].float() + qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs] a = torch.bmm(qi * scale, ki.transpose(1, 2)) if attn_bias is not None: - bias = attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias - a = a + bias.float() + a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias) outs.append(torch.bmm(a.softmax(dim=-1), vi)) - return torch.cat(outs).to(orig_dtype) + return torch.cat(outs) if _active_attn == "sdpa": - # xformers uses fp32 accumulation internally; match that by computing - # in fp32 and casting back. - orig_dtype = q.dtype - bias = attn_bias.float() if attn_bias is not None else None - out = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), attn_mask=bias) - return out.to(orig_dtype) + return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) # SageAttention variants: need 4D tensors (batch, heads, seq, dim) fn = _ATTN_BACKENDS[_active_attn] return fn( diff --git a/nodes.py b/nodes.py index 265cf55..c53724c 100644 --- a/nodes.py +++ b/nodes.py @@ -97,26 +97,19 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None): if _active_attn == "xformers": return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op) if _active_attn == "math": - # Naive batched attention with fp32 accumulation (matches xformers). - orig_dtype = q.dtype + # Naive batched attention — slow but guaranteed correct. scale = q.shape[-1] ** -0.5 - cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * 4)) + cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1))) outs = [] for i in range(0, q.shape[0], cs): - qi, ki, vi = q[i:i+cs].float(), k[i:i+cs].float(), v[i:i+cs].float() + qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs] a = torch.bmm(qi * scale, ki.transpose(1, 2)) if attn_bias is not None: - bias = attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias - a = a + bias.float() + a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias) outs.append(torch.bmm(a.softmax(dim=-1), vi)) - return torch.cat(outs).to(orig_dtype) + return torch.cat(outs) if _active_attn == "sdpa": - # xformers uses fp32 accumulation internally; match that by computing - # in fp32 and casting back. - orig_dtype = q.dtype - bias = attn_bias.float() if attn_bias is not None else None - out = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), attn_mask=bias) - return out.to(orig_dtype) + return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) # SageAttention variants: need 4D tensors (batch, heads, seq, dim) fn = _ATTN_BACKENDS[_active_attn] return fn( diff --git a/star_pipeline.py b/star_pipeline.py index 643faed..3a05fe8 100644 --- a/star_pipeline.py +++ b/star_pipeline.py @@ -265,7 +265,7 @@ def run_star_inference( if offload == "aggressive": vae_dec_chunk = 1 - total_noise_levels = 1000 + total_noise_levels = 900 # -- Convert ComfyUI frames to STAR format -- video_data = comfyui_to_star_frames(images) # [F, 3, H, W] @@ -291,7 +291,7 @@ def run_star_inference( if offload != "disabled": text_encoder.model.to(device) text_encoder.device = device - text = prompt if prompt.strip() else cfg.positive_prompt + text = (prompt if prompt.strip() else "") + cfg.positive_prompt y = text_encoder(text).detach() if offload != "disabled": text_encoder.model.to("cpu")