Fix noise level (900 not 1000) and prompt concatenation to match original STAR

The original STAR inference uses total_noise_levels=900, preserving input
structure during SDEdit. We had 1000 which starts from near-pure noise,
destroying the input. Also always append the quality prompt to user text
instead of using it only as a fallback.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-15 02:03:34 +01:00
parent 2bf8db4f07
commit 8a440761d1
3 changed files with 14 additions and 28 deletions

View File

@@ -97,26 +97,19 @@ def _dispatched_mea(q, k, v, attn_bias=None, op=None):
if _active_attn == "xformers":
return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
if _active_attn == "math":
# Naive batched attention with fp32 accumulation (matches xformers).
orig_dtype = q.dtype
# Naive batched attention — slow but guaranteed correct.
scale = q.shape[-1] ** -0.5
cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * 4))
cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1)))
outs = []
for i in range(0, q.shape[0], cs):
qi, ki, vi = q[i:i+cs].float(), k[i:i+cs].float(), v[i:i+cs].float()
qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
a = torch.bmm(qi * scale, ki.transpose(1, 2))
if attn_bias is not None:
bias = attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias
a = a + bias.float()
a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
outs.append(torch.bmm(a.softmax(dim=-1), vi))
return torch.cat(outs).to(orig_dtype)
return torch.cat(outs)
if _active_attn == "sdpa":
# xformers uses fp32 accumulation internally; match that by computing
# in fp32 and casting back.
orig_dtype = q.dtype
bias = attn_bias.float() if attn_bias is not None else None
out = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), attn_mask=bias)
return out.to(orig_dtype)
return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
# SageAttention variants: need 4D tensors (batch, heads, seq, dim)
fn = _ATTN_BACKENDS[_active_attn]
return fn(