Add xformers compatibility shim using PyTorch native SDPA
Avoids requiring xformers installation by shimming xformers.ops.memory_efficient_attention with torch.nn.functional.scaled_dot_product_attention when xformers is not available. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
23
nodes.py
23
nodes.py
@@ -27,6 +27,29 @@ if not os.path.isdir(os.path.join(STAR_REPO, "video_to_video")):
|
||||
if STAR_REPO not in sys.path:
|
||||
sys.path.insert(0, STAR_REPO)
|
||||
|
||||
# Provide an xformers compatibility shim using PyTorch's native SDPA if xformers
|
||||
# is not installed. The STAR UNet only uses xformers.ops.memory_efficient_attention
|
||||
# which is functionally equivalent to torch.nn.functional.scaled_dot_product_attention.
|
||||
if "xformers" not in sys.modules:
|
||||
try:
|
||||
import xformers # noqa: F401
|
||||
except ImportError:
|
||||
import types
|
||||
|
||||
_xformers = types.ModuleType("xformers")
|
||||
_xformers_ops = types.ModuleType("xformers.ops")
|
||||
|
||||
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
||||
return torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_bias,
|
||||
)
|
||||
|
||||
_xformers_ops.memory_efficient_attention = _memory_efficient_attention
|
||||
_xformers.ops = _xformers_ops
|
||||
sys.modules["xformers"] = _xformers
|
||||
sys.modules["xformers.ops"] = _xformers_ops
|
||||
print("[STAR] xformers not found — using PyTorch native SDPA as fallback.")
|
||||
|
||||
# Known models on HuggingFace that can be auto-downloaded.
|
||||
HF_REPO = "SherryX/STAR"
|
||||
HF_MODELS = {
|
||||
|
||||
Reference in New Issue
Block a user