diff --git a/inference.py b/inference.py
index 0be4f1c..39d531c 100755
--- a/inference.py
+++ b/inference.py
@@ -68,6 +68,36 @@
 sys.modules["comfy"] = _comfy
 sys.modules["comfy.utils"] = _comfy_utils
 sys.modules["comfy.model_management"] = _comfy_mm
+# ── xformers compatibility shim (use PyTorch native SDPA if unavailable) ──
+if "xformers" not in sys.modules:
+    try:
+        import xformers  # noqa: F401
+    except ImportError:
+        _xformers = types.ModuleType("xformers")
+        _xformers_ops = types.ModuleType("xformers.ops")
+
+        def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
+            """Drop-in replacement for xformers.ops.memory_efficient_attention.
+
+            xformers takes inputs as (B, M, H, K) — sequence before heads —
+            while torch SDPA expects (B, H, M, K), so 4-D tensors are
+            transposed on the way in and back on the way out.  3-D (B, M, K)
+            inputs already match SDPA's (N, L, E) form and pass through.
+            The ``op`` kernel-selector argument is accepted for signature
+            parity and ignored.  NOTE(review): assumes ``attn_bias`` is None
+            or an additive tensor mask; xformers AttentionBias objects would
+            need materializing first — confirm against STAR's call sites.
+            """
+            sdpa = torch.nn.functional.scaled_dot_product_attention
+            if q.ndim == 4:
+                q, k, v = (t.transpose(1, 2) for t in (q, k, v))
+                return sdpa(q, k, v, attn_mask=attn_bias).transpose(1, 2)
+            return sdpa(q, k, v, attn_mask=attn_bias)
+
+        _xformers_ops.memory_efficient_attention = _memory_efficient_attention
+        _xformers.ops = _xformers_ops
+        sys.modules["xformers"] = _xformers
+        sys.modules["xformers.ops"] = _xformers_ops
+
 # ── Standard imports ────────────────────────────────────────────────────
 import argparse  # noqa: E402
 import json  # noqa: E402
diff --git a/nodes.py b/nodes.py
index 8710992..e0e061c 100644
--- a/nodes.py
+++ b/nodes.py
@@ -27,6 +27,41 @@ if not os.path.isdir(os.path.join(STAR_REPO, "video_to_video")):
 if STAR_REPO not in sys.path:
     sys.path.insert(0, STAR_REPO)
 
+# xformers compatibility shim: when xformers is not installed, emulate
+# xformers.ops.memory_efficient_attention with PyTorch's native
+# scaled_dot_product_attention (layouts differ — see docstring below).
+if "xformers" not in sys.modules:
+    try:
+        import xformers  # noqa: F401
+    except ImportError:
+        import types
+
+        _xformers = types.ModuleType("xformers")
+        _xformers_ops = types.ModuleType("xformers.ops")
+
+        def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
+            """Drop-in replacement for xformers.ops.memory_efficient_attention.
+
+            xformers takes inputs as (B, M, H, K) — sequence before heads —
+            while torch SDPA expects (B, H, M, K), so 4-D tensors are
+            transposed on the way in and back on the way out.  3-D (B, M, K)
+            inputs already match SDPA's (N, L, E) form and pass through.
+            The ``op`` kernel-selector argument is accepted for signature
+            parity and ignored.  NOTE(review): assumes ``attn_bias`` is None
+            or an additive tensor mask; xformers AttentionBias objects would
+            need materializing first — confirm against STAR's call sites.
+            """
+            sdpa = torch.nn.functional.scaled_dot_product_attention
+            if q.ndim == 4:
+                q, k, v = (t.transpose(1, 2) for t in (q, k, v))
+                return sdpa(q, k, v, attn_mask=attn_bias).transpose(1, 2)
+            return sdpa(q, k, v, attn_mask=attn_bias)
+
+        _xformers_ops.memory_efficient_attention = _memory_efficient_attention
+        _xformers.ops = _xformers_ops
+        sys.modules["xformers"] = _xformers
+        sys.modules["xformers.ops"] = _xformers_ops
+        print("[STAR] xformers not found — using PyTorch native SDPA as fallback.")
+
 # Known models on HuggingFace that can be auto-downloaded.
 HF_REPO = "SherryX/STAR"
 HF_MODELS = {