Add configurable attention backend with SageAttention variant support
Replace the auto-detect xformers shim with a runtime dispatcher that always
intercepts xformers.ops.memory_efficient_attention. A new dropdown on
STARModelLoader (and an --attention CLI arg) lets users explicitly select:
sdpa (default), xformers, sageattn, or a specific SageAttention kernel
(fp16 triton/cuda, fp8 cuda). Only backends that import successfully appear
as options.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
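The core idea, sketched in isolation below: a stand-in xformers module is registered in sys.modules, so model code that calls xformers.ops.memory_efficient_attention(...) is routed through a dispatcher function whose backend can be switched at runtime. This is a minimal illustration, not the shipped code; the tensor shapes and the setdefault guard are assumptions made for the sketch.

import sys
import types

import torch
import torch.nn.functional as F


def _dispatched_mea(q, k, v, attn_bias=None, op=None):
    # Default route: PyTorch-native SDPA, which is always available.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)


# Register stand-in modules unless an "xformers" entry already exists in sys.modules.
_fake = types.ModuleType("xformers")
_fake_ops = types.ModuleType("xformers.ops")
_fake_ops.memory_efficient_attention = _dispatched_mea
_fake.ops = _fake_ops
sys.modules.setdefault("xformers", _fake)
sys.modules.setdefault("xformers.ops", _fake_ops)

import xformers.ops  # resolves to whatever is registered under "xformers" in sys.modules

q = k = v = torch.randn(8, 77, 64)  # (heads, tokens, head_dim); shapes are illustrative
out = xformers.ops.memory_efficient_attention(q, k, v)
print(out.shape)  # torch.Size([8, 77, 64])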
inference.py
@@ -68,33 +68,71 @@ sys.modules["comfy"] = _comfy
 sys.modules["comfy.utils"] = _comfy_utils
 sys.modules["comfy.model_management"] = _comfy_mm
 
-# ── xformers compatibility shim ──────────────────────────────────────────
-# Priority: SageAttention (fastest) > PyTorch native SDPA (always available).
-if "xformers" not in sys.modules:
-    try:
-        import xformers  # noqa: F401
-    except ImportError:
-        _xformers = types.ModuleType("xformers")
-        _xformers_ops = types.ModuleType("xformers.ops")
-
-        try:
-            from sageattention import sageattn as _sageattn
-
-            def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
-                return _sageattn(
-                    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
-                    tensor_layout="HND", is_causal=False,
-                ).squeeze(0)
-        except ImportError:
-            def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
-                return torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, attn_mask=attn_bias,
-                )
-
-        _xformers_ops.memory_efficient_attention = _memory_efficient_attention
-        _xformers.ops = _xformers_ops
-        sys.modules["xformers"] = _xformers
-        sys.modules["xformers.ops"] = _xformers_ops
+# ── Attention backend dispatcher ──────────────────────────────────────
+import torch.nn.functional as F  # noqa: E402
+
+_ATTN_BACKENDS = {"sdpa": None}
+
+_real_xformers_mea = None
+try:
+    import xformers.ops
+    _candidate = xformers.ops.memory_efficient_attention
+    if not getattr(_candidate, "_is_star_dispatcher", False):
+        _real_xformers_mea = _candidate
+        _ATTN_BACKENDS["xformers"] = _real_xformers_mea
+except ImportError:
+    pass
+
+_SAGE_VARIANTS = [
+    "sageattn",
+    "sageattn_qk_int8_pv_fp16_triton",
+    "sageattn_qk_int8_pv_fp16_cuda",
+    "sageattn_qk_int8_pv_fp8_cuda",
+]
+for _name in _SAGE_VARIANTS:
+    try:
+        _fn = getattr(__import__("sageattention", fromlist=[_name]), _name)
+        _ATTN_BACKENDS[_name] = _fn
+    except (ImportError, AttributeError):
+        pass
+
+_active_attn = "sdpa"
+
+
+def _set_attn(backend: str):
+    global _active_attn
+    if backend not in _ATTN_BACKENDS:
+        print(f"[STAR] Warning: backend '{backend}' not available, falling back to sdpa")
+        backend = "sdpa"
+    _active_attn = backend
+    print(f"[STAR] Attention backend: {backend}")
+
+
+def _dispatched_mea(q, k, v, attn_bias=None, op=None):
+    if _active_attn == "sdpa":
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+    if _active_attn == "xformers":
+        return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
+    fn = _ATTN_BACKENDS[_active_attn]
+    return fn(
+        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
+        tensor_layout="HND", is_causal=False,
+    ).squeeze(0)
+
+
+_dispatched_mea._is_star_dispatcher = True
+
+if "xformers" in sys.modules:
+    sys.modules["xformers"].ops.memory_efficient_attention = _dispatched_mea
+else:
+    _xformers = types.ModuleType("xformers")
+    _xformers_ops = types.ModuleType("xformers.ops")
+    _xformers_ops.memory_efficient_attention = _dispatched_mea
+    _xformers.ops = _xformers_ops
+    sys.modules["xformers"] = _xformers
+    sys.modules["xformers.ops"] = _xformers_ops
+
+print(f"[STAR] Available attention backends: {list(_ATTN_BACKENDS.keys())}")
 
 # ── Standard imports ────────────────────────────────────────────────────
 import argparse  # noqa: E402
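One detail in the dispatcher above worth spelling out: xformers-style callers in this file pass 3-D (heads, tokens, head_dim) tensors, while the SageAttention kernels take a leading batch axis plus an explicit tensor_layout, so the dispatcher wraps each call in an unsqueeze/squeeze pair. Below is a standalone sketch of that adaptation; the shapes are assumptions, and a stand-in kernel is used so the sketch runs without SageAttention installed.

import torch
import torch.nn.functional as F


def adapt_for_sage(kernel, q, k, v):
    # (H, N, D) -> (1, H, N, D): add the batch axis the kernel expects,
    # call it with an explicit layout, then drop the axis again.
    out = kernel(
        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
        tensor_layout="HND", is_causal=False,
    )
    return out.squeeze(0)


def fake_kernel(q, k, v, tensor_layout, is_causal):
    # Stand-in kernel; mirrors the call shape the dispatcher uses above.
    return F.scaled_dot_product_attention(q, k, v)


q = k = v = torch.randn(8, 77, 64)  # (heads, tokens, head_dim); illustrative only
print(adapt_for_sage(fake_kernel, q, k, v).shape)  # torch.Size([8, 77, 64])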
@@ -159,6 +197,9 @@ def parse_args():
                    help="Post-processing color correction")
     g.add_argument("--prompt", default="",
                    help="Text prompt (empty = STAR built-in quality prompt)")
+    g.add_argument("--attention", default="sdpa",
+                   choices=list(_ATTN_BACKENDS.keys()),
+                   help="Attention backend")
 
     # -- Video output --
     g = p.add_argument_group("video output")
@@ -583,6 +624,8 @@ def main():
     print(f"[STAR] Model: {model_path}")
     star_model = load_model(model_path, args.precision, args.offload, device)
 
+    _set_attn(args.attention)
+
     # Create writer and process
     writer = make_writer(output_path, fps, w_out, h_out, args, args.input, is_single)
     process_and_stream(star_model, input_frames, writer, args)
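Since only importable backends are offered as --attention choices, a quick way to see what a given machine will accept is to mirror the probe loop from the dispatcher; if an unavailable name is passed anyway, _set_attn prints a warning and falls back to sdpa. A hypothetical standalone probe (not part of the commit) could look like this:

import importlib

VARIANTS = [
    "sageattn",
    "sageattn_qk_int8_pv_fp16_triton",
    "sageattn_qk_int8_pv_fp16_cuda",
    "sageattn_qk_int8_pv_fp8_cuda",
]

available = ["sdpa"]  # always present via torch
for name in VARIANTS:
    try:
        getattr(importlib.import_module("sageattention"), name)
        available.append(name)
    except (ImportError, AttributeError):
        pass

try:
    import xformers.ops  # noqa: F401
    available.append("xformers")
except ImportError:
    pass

print("usable --attention values:", available)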