Add xformers compatibility shim using PyTorch native SDPA

Avoids requiring xformers installation by shimming
xformers.ops.memory_efficient_attention with
torch.nn.functional.scaled_dot_product_attention when
xformers is not available.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 23:58:55 +01:00
parent 5786ab6be7
commit 5de26d8ead
2 changed files with 41 additions and 0 deletions
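
The mechanism relies on Python's import cache: registering stub modules under the xformers names in sys.modules before STAR imports them makes `import xformers.ops` resolve to the shim instead of failing. A minimal standalone sketch of the same technique (assuming only torch is installed and xformers is absent; tensor shapes are illustrative):

import sys
import types

import torch

# Build stub modules and register them under the names the import system checks.
_xformers = types.ModuleType("xformers")
_xformers.ops = types.ModuleType("xformers.ops")

def _sdpa_fallback(q, k, v, attn_bias=None, op=None):
    # Same signature as the shim in this commit; `op` is accepted but ignored.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)

_xformers.ops.memory_efficient_attention = _sdpa_fallback
sys.modules["xformers"] = _xformers
sys.modules["xformers.ops"] = _xformers.ops

# Downstream code now picks up the shim transparently.
import xformers.ops

q = k = v = torch.randn(2, 77, 64)  # 3-D (batch*heads, seq_len, head_dim)
out = xformers.ops.memory_efficient_attention(q, k, v)
print(out.shape)  # torch.Size([2, 77, 64])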


@@ -68,6 +68,24 @@ sys.modules["comfy"] = _comfy
sys.modules["comfy.utils"] = _comfy_utils
sys.modules["comfy.model_management"] = _comfy_mm
# ── xformers compatibility shim (use PyTorch native SDPA if unavailable) ──
if "xformers" not in sys.modules:
try:
import xformers # noqa: F401
except ImportError:
_xformers = types.ModuleType("xformers")
_xformers_ops = types.ModuleType("xformers.ops")
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
return torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_bias,
)
_xformers_ops.memory_efficient_attention = _memory_efficient_attention
_xformers.ops = _xformers_ops
sys.modules["xformers"] = _xformers
sys.modules["xformers.ops"] = _xformers_ops
# ── Standard imports ────────────────────────────────────────────────────
import argparse # noqa: E402
import json # noqa: E402
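
One caveat on the pass-through above: the two functions agree for 3-D (batch*heads, seq_len, head_dim) inputs, but for 4-D inputs xformers expects (batch, seq_len, num_heads, head_dim) while SDPA expects (batch, num_heads, seq_len, head_dim). If the calling code ever passed 4-D tensors, a layout-aware variant of the fallback (a sketch, not part of this commit) would transpose around the call:

import torch.nn.functional as F

def _memory_efficient_attention_4d(q, k, v, attn_bias=None, op=None):
    # xformers 4-D layout (B, M, H, K) -> SDPA layout (B, H, M, K)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
    return out.transpose(1, 2)  # back to xformers layout (B, M, H, K)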


@@ -27,6 +27,29 @@ if not os.path.isdir(os.path.join(STAR_REPO, "video_to_video")):
if STAR_REPO not in sys.path:
    sys.path.insert(0, STAR_REPO)

# Provide an xformers compatibility shim using PyTorch's native SDPA if xformers
# is not installed. The STAR UNet only uses xformers.ops.memory_efficient_attention,
# which is functionally equivalent to torch.nn.functional.scaled_dot_product_attention.
if "xformers" not in sys.modules:
    try:
        import xformers  # noqa: F401
    except ImportError:
        import types

        _xformers = types.ModuleType("xformers")
        _xformers_ops = types.ModuleType("xformers.ops")

        def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
            return torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_bias,
            )

        _xformers_ops.memory_efficient_attention = _memory_efficient_attention
        _xformers.ops = _xformers_ops
        sys.modules["xformers"] = _xformers
        sys.modules["xformers.ops"] = _xformers_ops
        print("[STAR] xformers not found; using PyTorch native SDPA as fallback.")

# Known models on HuggingFace that can be auto-downloaded.
HF_REPO = "SherryX/STAR"
HF_MODELS = {