diff --git a/inference.py b/inference.py
index 08b5d9c..afa2d38 100755
--- a/inference.py
+++ b/inference.py
@@ -68,33 +68,71 @@
 sys.modules["comfy"] = _comfy
 sys.modules["comfy.utils"] = _comfy_utils
 sys.modules["comfy.model_management"] = _comfy_mm
 
-# ── xformers compatibility shim ──────────────────────────────────────────
-# Priority: SageAttention (fastest) > PyTorch native SDPA (always available).
-if "xformers" not in sys.modules:
+# ── Attention backend dispatcher ──────────────────────────────────────
+import torch.nn.functional as F  # noqa: E402
+
+_ATTN_BACKENDS = {"sdpa": None}  # None = use F.scaled_dot_product_attention directly
+
+_real_xformers_mea = None
+try:
+    import xformers.ops
+    _candidate = xformers.ops.memory_efficient_attention
+    if not getattr(_candidate, "_is_star_dispatcher", False):
+        _real_xformers_mea = _candidate
+        _ATTN_BACKENDS["xformers"] = _real_xformers_mea
+except ImportError:
+    pass
+
+_SAGE_VARIANTS = [
+    "sageattn",
+    "sageattn_qk_int8_pv_fp16_triton",
+    "sageattn_qk_int8_pv_fp16_cuda",
+    "sageattn_qk_int8_pv_fp8_cuda",
+]
+for _name in _SAGE_VARIANTS:
     try:
-        import xformers  # noqa: F401
-    except ImportError:
-        _xformers = types.ModuleType("xformers")
-        _xformers_ops = types.ModuleType("xformers.ops")
+        _fn = getattr(__import__("sageattention", fromlist=[_name]), _name)
+        _ATTN_BACKENDS[_name] = _fn
+    except (ImportError, AttributeError):
+        pass
 
-        try:
-            from sageattention import sageattn as _sageattn
+_active_attn = "sdpa"
 
-            def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
-                return _sageattn(
-                    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
-                    tensor_layout="HND", is_causal=False,
-                ).squeeze(0)
-        except ImportError:
-            def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
-                return torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, attn_mask=attn_bias,
-                )
-        _xformers_ops.memory_efficient_attention = _memory_efficient_attention
-        _xformers.ops = _xformers_ops
-        sys.modules["xformers"] = _xformers
-        sys.modules["xformers.ops"] = _xformers_ops
 
+def _set_attn(backend: str):
+    global _active_attn
+    if backend not in _ATTN_BACKENDS:
+        print(f"[STAR] Warning: backend '{backend}' not available, falling back to sdpa")
+        backend = "sdpa"
+    _active_attn = backend
+    print(f"[STAR] Attention backend: {backend}")
+
+
+def _dispatched_mea(q, k, v, attn_bias=None, op=None):
+    if _active_attn == "sdpa":
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+    if _active_attn == "xformers":
+        return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
+    fn = _ATTN_BACKENDS[_active_attn]  # SageAttention variant: needs 4D (batch, heads, seq, dim)
+    return fn(
+        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
+        tensor_layout="HND", is_causal=False,
+    ).squeeze(0)
+
+
+_dispatched_mea._is_star_dispatcher = True  # marker: never re-capture our own dispatcher
+
+if "xformers" in sys.modules:
+    sys.modules["xformers"].ops.memory_efficient_attention = _dispatched_mea
+else:
+    _xformers = types.ModuleType("xformers")
+    _xformers_ops = types.ModuleType("xformers.ops")
+    _xformers_ops.memory_efficient_attention = _dispatched_mea
+    _xformers.ops = _xformers_ops
+    sys.modules["xformers"] = _xformers
+    sys.modules["xformers.ops"] = _xformers_ops
+
+print(f"[STAR] Available attention backends: {list(_ATTN_BACKENDS.keys())}")
 
 # ── Standard imports ────────────────────────────────────────────────────
 import argparse  # noqa: E402
@@ -159,6 +197,9 @@ def parse_args():
                    help="Post-processing color correction")
     g.add_argument("--prompt", default="",
                    help="Text prompt (empty = STAR built-in quality prompt)")
+    g.add_argument("--attention", default="sdpa",
+                   choices=list(_ATTN_BACKENDS.keys()),
+                   help="Attention backend")
 
     # -- Video output --
     g = p.add_argument_group("video output")
@@ -583,6 +624,8 @@ def main():
     print(f"[STAR] Model: {model_path}")
 
     star_model = load_model(model_path, args.precision, args.offload, device)
+    _set_attn(args.attention)
+
     # Create writer and process
     writer = make_writer(output_path, fps, w_out, h_out, args, args.input, is_single)
     process_and_stream(star_model, input_frames, writer, args)
diff --git a/nodes.py b/nodes.py
index ee9527d..1bf579e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,6 +1,8 @@
 import os
 import sys
+import types
 import torch
+import torch.nn.functional as F
 import folder_paths
 
 import comfy.model_management as mm
@@ -27,41 +29,77 @@ if not os.path.isdir(os.path.join(STAR_REPO, "video_to_video")):
 if STAR_REPO not in sys.path:
     sys.path.insert(0, STAR_REPO)
 
-# Provide an xformers compatibility shim if xformers is not installed.
-# The STAR UNet only uses xformers.ops.memory_efficient_attention.
-# Priority: SageAttention (fastest, INT8 quantized) > PyTorch native SDPA (always available).
-if "xformers" not in sys.modules:
+# ── Attention backend dispatcher ──────────────────────────────────────
+# Build a registry of available backends at import time.
+# sdpa (PyTorch native) is always available and is the default.
+_ATTN_BACKENDS = {"sdpa": None}  # None = use F.scaled_dot_product_attention directly
+
+# Try real xformers — guard against capturing our own dispatcher on reload
+# or another node's shim by checking for a marker attribute.
+_real_xformers_mea = None
+try:
+    import xformers.ops
+    _candidate = xformers.ops.memory_efficient_attention
+    if not getattr(_candidate, "_is_star_dispatcher", False):
+        _real_xformers_mea = _candidate
+        _ATTN_BACKENDS["xformers"] = _real_xformers_mea
+except ImportError:
+    pass
+
+# Try SageAttention variants
+_SAGE_VARIANTS = [
+    "sageattn",
+    "sageattn_qk_int8_pv_fp16_triton",
+    "sageattn_qk_int8_pv_fp16_cuda",
+    "sageattn_qk_int8_pv_fp8_cuda",
+]
+for _name in _SAGE_VARIANTS:
     try:
-        import xformers  # noqa: F401
-    except ImportError:
-        import types
+        _fn = getattr(__import__("sageattention", fromlist=[_name]), _name)
+        _ATTN_BACKENDS[_name] = _fn
+    except (ImportError, AttributeError):
+        pass
 
-        _xformers = types.ModuleType("xformers")
-        _xformers_ops = types.ModuleType("xformers.ops")
+_active_attn = "sdpa"
 
-        try:
-            from sageattention import sageattn as _sageattn
 
-            def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
-                # STAR UNet passes 3D (B*heads, seq, dim); SageAttention needs 4D.
-                return _sageattn(
-                    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
-                    tensor_layout="HND", is_causal=False,
-                ).squeeze(0)
+def _set_attn(backend: str):
+    global _active_attn
+    if backend not in _ATTN_BACKENDS:
+        print(f"[STAR] Warning: backend '{backend}' not available, falling back to sdpa")
+        backend = "sdpa"
+    _active_attn = backend
+    print(f"[STAR] Attention backend: {backend}")
 
-            print("[STAR] xformers not found — using SageAttention (fast INT8 quantized).")
-        except ImportError:
-            def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
-                return torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, attn_mask=attn_bias,
-                )
-            print("[STAR] xformers not found — using PyTorch native SDPA as fallback.")
 
+def _dispatched_mea(q, k, v, attn_bias=None, op=None):
+    if _active_attn == "sdpa":
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+    if _active_attn == "xformers":
+        return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
+    # SageAttention variants: need 4D tensors (batch, heads, seq, dim)
+    fn = _ATTN_BACKENDS[_active_attn]
+    return fn(
+        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
+        tensor_layout="HND", is_causal=False,
+    ).squeeze(0)
 
-        _xformers_ops.memory_efficient_attention = _memory_efficient_attention
-        _xformers.ops = _xformers_ops
-        sys.modules["xformers"] = _xformers
-        sys.modules["xformers.ops"] = _xformers_ops
+
+_dispatched_mea._is_star_dispatcher = True
+
+# Always install the dispatcher as xformers.ops.memory_efficient_attention
+# so the dropdown controls what actually runs, whether or not real xformers is installed.
+if "xformers" in sys.modules:
+    sys.modules["xformers"].ops.memory_efficient_attention = _dispatched_mea
+else:
+    _xformers = types.ModuleType("xformers")
+    _xformers_ops = types.ModuleType("xformers.ops")
+    _xformers_ops.memory_efficient_attention = _dispatched_mea
+    _xformers.ops = _xformers_ops
+    sys.modules["xformers"] = _xformers
+    sys.modules["xformers.ops"] = _xformers_ops
+
+print(f"[STAR] Available attention backends: {list(_ATTN_BACKENDS.keys())}")
 
 # Known models on HuggingFace that can be auto-downloaded.
 HF_REPO = "SherryX/STAR"
@@ -121,6 +159,12 @@
                     "default": "disabled",
                     "tooltip": "disabled: all on GPU (~39GB). model: swap UNet/VAE/CLIP to CPU when idle (~16GB). aggressive: model offload + single-frame VAE decode (~12GB).",
                 }),
+                "attention": (list(_ATTN_BACKENDS.keys()), {
+                    "default": "sdpa",
+                    "tooltip": "Attention backend. sdpa: PyTorch native (default, always available). "
+                               "xformers: original backend. sageattn: SageAttention auto-select. "
+                               "Other sageattn_* variants: pick a specific SageAttention kernel to tune the speed/precision trade-off.",
+                }),
             }
         }
 
@@ -130,7 +174,7 @@
     CATEGORY = "STAR"
     DESCRIPTION = "Loads the STAR video super-resolution model (UNet+ControlNet, OpenCLIP text encoder, temporal VAE). All components are auto-downloaded on first use."
 
- def load_model(self, model_name, precision, offload="disabled"): + def load_model(self, model_name, precision, offload="disabled", attention="sdpa"): device = mm.get_torch_device() dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32} dtype = dtype_map[precision] @@ -204,6 +248,7 @@ class STARModelLoader: "device": device, "dtype": dtype, "offload": offload, + "attention": attention, } return (star_model,) @@ -278,6 +323,8 @@ class STARVideoSuperResolution: color_fix, segment_size=0, ): + _set_attn(star_model.get("attention", "sdpa")) + kwargs = dict( star_model=star_model, images=images,
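
A quick way to exercise the dispatcher once this patch is applied: import the patched module so the shim installs itself, then push one small attention call through each registered backend and compare against the SDPA reference. The sketch below is illustrative and not part of the patch; it assumes a CUDA device and fp16 tensors (the SageAttention kernels run on GPU in half precision), that the patched nodes.py is importable as `nodes` (inside a ComfyUI install the import path will differ), and that kernels your GPU does not support may simply raise at call time.

    import sys
    import torch
    import torch.nn.functional as F

    import nodes  # installs the dispatcher as xformers.ops.memory_efficient_attention

    mea = sys.modules["xformers"].ops.memory_efficient_attention

    # STAR's UNet calls the shim with 3D (B*heads, seq, head_dim) tensors.
    q = torch.randn(8, 256, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    reference = F.scaled_dot_product_attention(q, k, v)
    for name in nodes._ATTN_BACKENDS:
        nodes._set_attn(name)
        out = mea(q, k, v)
        # INT8-quantized SageAttention kernels will not match bit-for-bit,
        # so report the max abs difference rather than asserting allclose.
        print(f"{name}: max abs diff vs sdpa = {(out - reference).abs().max().item():.4f}")

Keeping sdpa as the default means a bare install with neither xformers nor SageAttention present behaves exactly as before, and the import-time print of available backends makes it easy to confirm what the dropdown or --attention flag can actually select.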