Add SageAttention as preferred attention backend when available
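Attention backend order: xformers (used as-is when installed) > SageAttention (2-5x faster, INT8 quantized) > PyTorch native SDPA. When xformers is missing, the compatibility shim now prefers SageAttention and falls back to SDPA only if SageAttention cannot be imported. SageAttention is optional — install with `pip install sageattention` for a speed boost. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>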
This commit is contained in:
20
inference.py
20
inference.py
@@ -68,7 +68,8 @@ sys.modules["comfy"] = _comfy
|
|||||||
sys.modules["comfy.utils"] = _comfy_utils
|
sys.modules["comfy.utils"] = _comfy_utils
|
||||||
sys.modules["comfy.model_management"] = _comfy_mm
|
sys.modules["comfy.model_management"] = _comfy_mm
|
||||||
|
|
||||||
# ── xformers compatibility shim (use PyTorch native SDPA if unavailable) ──
|
# ── xformers compatibility shim ──────────────────────────────────────────
|
||||||
|
# Used only when xformers itself cannot be imported.
# Shim priority: SageAttention (fastest) > PyTorch native SDPA (always available).
|
||||||
if "xformers" not in sys.modules:
|
if "xformers" not in sys.modules:
|
||||||
try:
|
try:
|
||||||
import xformers # noqa: F401
|
import xformers # noqa: F401
|
||||||
@@ -76,10 +77,19 @@ if "xformers" not in sys.modules:
|
|||||||
_xformers = types.ModuleType("xformers")
|
_xformers = types.ModuleType("xformers")
|
||||||
_xformers_ops = types.ModuleType("xformers.ops")
|
_xformers_ops = types.ModuleType("xformers.ops")
|
||||||
|
|
||||||
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
try:
|
||||||
return torch.nn.functional.scaled_dot_product_attention(
|
from sageattention import sageattn as _sageattn
|
||||||
q, k, v, attn_mask=attn_bias,
|
|
||||||
)
|
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
||||||
|
return _sageattn(
|
||||||
|
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
|
||||||
|
tensor_layout="HND", is_causal=False,
|
||||||
|
).squeeze(0)
|
||||||
|
except ImportError:
|
||||||
|
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
||||||
|
return torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask=attn_bias,
|
||||||
|
)
|
||||||
|
|
||||||
_xformers_ops.memory_efficient_attention = _memory_efficient_attention
|
_xformers_ops.memory_efficient_attention = _memory_efficient_attention
|
||||||
_xformers.ops = _xformers_ops
|
_xformers.ops = _xformers_ops
|
||||||
|
|||||||
29
nodes.py
29
nodes.py
@@ -27,9 +27,9 @@ if not os.path.isdir(os.path.join(STAR_REPO, "video_to_video")):
|
|||||||
if STAR_REPO not in sys.path:
|
if STAR_REPO not in sys.path:
|
||||||
sys.path.insert(0, STAR_REPO)
|
sys.path.insert(0, STAR_REPO)
|
||||||
|
|
||||||
# Provide an xformers compatibility shim using PyTorch's native SDPA if xformers
|
# Provide an xformers compatibility shim if xformers is not installed.
|
||||||
# is not installed. The STAR UNet only uses xformers.ops.memory_efficient_attention
|
# The STAR UNet only uses xformers.ops.memory_efficient_attention.
|
||||||
# which is functionally equivalent to torch.nn.functional.scaled_dot_product_attention.
|
# Shim priority (only when xformers is missing): SageAttention (fastest, INT8 quantized) > PyTorch native SDPA (always available).
|
||||||
if "xformers" not in sys.modules:
|
if "xformers" not in sys.modules:
|
||||||
try:
|
try:
|
||||||
import xformers # noqa: F401
|
import xformers # noqa: F401
|
||||||
@@ -39,16 +39,29 @@ if "xformers" not in sys.modules:
|
|||||||
_xformers = types.ModuleType("xformers")
|
_xformers = types.ModuleType("xformers")
|
||||||
_xformers_ops = types.ModuleType("xformers.ops")
|
_xformers_ops = types.ModuleType("xformers.ops")
|
||||||
|
|
||||||
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
try:
|
||||||
return torch.nn.functional.scaled_dot_product_attention(
|
from sageattention import sageattn as _sageattn
|
||||||
q, k, v, attn_mask=attn_bias,
|
|
||||||
)
|
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
||||||
|
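# STAR UNet passes 3D (B*heads, seq, dim); SageAttention needs 4D, so a
# leading batch axis of 1 is added and removed again after the call.
# NOTE: attn_bias and op are accepted for API compatibility but are not
# forwarded to SageAttention on this path.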
|
||||||
|
return _sageattn(
|
||||||
|
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
|
||||||
|
tensor_layout="HND", is_causal=False,
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
print("[STAR] xformers not found — using SageAttention (fast INT8 quantized).")
|
||||||
|
except ImportError:
|
||||||
|
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
||||||
|
return torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask=attn_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("[STAR] xformers not found — using PyTorch native SDPA as fallback.")
|
||||||
|
|
||||||
_xformers_ops.memory_efficient_attention = _memory_efficient_attention
|
_xformers_ops.memory_efficient_attention = _memory_efficient_attention
|
||||||
_xformers.ops = _xformers_ops
|
_xformers.ops = _xformers_ops
|
||||||
sys.modules["xformers"] = _xformers
|
sys.modules["xformers"] = _xformers
|
||||||
sys.modules["xformers.ops"] = _xformers_ops
|
sys.modules["xformers.ops"] = _xformers_ops
|
||||||
print("[STAR] xformers not found — using PyTorch native SDPA as fallback.")
|
|
||||||
|
|
||||||
# Known models on HuggingFace that can be auto-downloaded.
|
# Known models on HuggingFace that can be auto-downloaded.
|
||||||
HF_REPO = "SherryX/STAR"
|
HF_REPO = "SherryX/STAR"
|
||||||
|
|||||||
Reference in New Issue
Block a user