Add configurable attention backend with SageAttention variant support
Replace the auto-detect xformers shim with a runtime dispatcher that always intercepts xformers.ops.memory_efficient_attention. A new dropdown on STARModelLoader (and --attention CLI arg) lets users explicitly select: sdpa (default), xformers, sageattn, or specific SageAttention kernels (fp16 triton/cuda, fp8 cuda). Only backends that successfully import appear as options. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
73
inference.py
73
inference.py
@@ -68,34 +68,72 @@ sys.modules["comfy"] = _comfy
|
|||||||
sys.modules["comfy.utils"] = _comfy_utils
|
sys.modules["comfy.utils"] = _comfy_utils
|
||||||
sys.modules["comfy.model_management"] = _comfy_mm
|
sys.modules["comfy.model_management"] = _comfy_mm
|
||||||
|
|
||||||
# ── xformers compatibility shim ──────────────────────────────────────────
|
# ── Attention backend dispatcher ──────────────────────────────────────
|
||||||
# Priority: SageAttention (fastest) > PyTorch native SDPA (always available).
|
import torch.nn.functional as F # noqa: E402
|
||||||
if "xformers" not in sys.modules:
|
|
||||||
|
_ATTN_BACKENDS = {"sdpa": None}
|
||||||
|
|
||||||
|
_real_xformers_mea = None
|
||||||
try:
|
try:
|
||||||
import xformers # noqa: F401
|
import xformers.ops
|
||||||
|
_candidate = xformers.ops.memory_efficient_attention
|
||||||
|
if not getattr(_candidate, "_is_star_dispatcher", False):
|
||||||
|
_real_xformers_mea = _candidate
|
||||||
|
_ATTN_BACKENDS["xformers"] = _real_xformers_mea
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_xformers = types.ModuleType("xformers")
|
pass
|
||||||
_xformers_ops = types.ModuleType("xformers.ops")
|
|
||||||
|
|
||||||
|
_SAGE_VARIANTS = [
|
||||||
|
"sageattn",
|
||||||
|
"sageattn_qk_int8_pv_fp16_triton",
|
||||||
|
"sageattn_qk_int8_pv_fp16_cuda",
|
||||||
|
"sageattn_qk_int8_pv_fp8_cuda",
|
||||||
|
]
|
||||||
|
for _name in _SAGE_VARIANTS:
|
||||||
try:
|
try:
|
||||||
from sageattention import sageattn as _sageattn
|
_fn = getattr(__import__("sageattention", fromlist=[_name]), _name)
|
||||||
|
_ATTN_BACKENDS[_name] = _fn
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
_active_attn = "sdpa"
|
||||||
return _sageattn(
|
|
||||||
|
|
||||||
|
def _set_attn(backend: str):
|
||||||
|
global _active_attn
|
||||||
|
if backend not in _ATTN_BACKENDS:
|
||||||
|
print(f"[STAR] Warning: backend '{backend}' not available, falling back to sdpa")
|
||||||
|
backend = "sdpa"
|
||||||
|
_active_attn = backend
|
||||||
|
print(f"[STAR] Attention backend: {backend}")
|
||||||
|
|
||||||
|
|
||||||
|
def _dispatched_mea(q, k, v, attn_bias=None, op=None):
|
||||||
|
if _active_attn == "sdpa":
|
||||||
|
return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
|
||||||
|
if _active_attn == "xformers":
|
||||||
|
return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
|
||||||
|
fn = _ATTN_BACKENDS[_active_attn]
|
||||||
|
return fn(
|
||||||
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
|
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
|
||||||
tensor_layout="HND", is_causal=False,
|
tensor_layout="HND", is_causal=False,
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
except ImportError:
|
|
||||||
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
|
||||||
return torch.nn.functional.scaled_dot_product_attention(
|
|
||||||
q, k, v, attn_mask=attn_bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
_xformers_ops.memory_efficient_attention = _memory_efficient_attention
|
|
||||||
|
_dispatched_mea._is_star_dispatcher = True
|
||||||
|
|
||||||
|
if "xformers" in sys.modules:
|
||||||
|
sys.modules["xformers"].ops.memory_efficient_attention = _dispatched_mea
|
||||||
|
else:
|
||||||
|
_xformers = types.ModuleType("xformers")
|
||||||
|
_xformers_ops = types.ModuleType("xformers.ops")
|
||||||
|
_xformers_ops.memory_efficient_attention = _dispatched_mea
|
||||||
_xformers.ops = _xformers_ops
|
_xformers.ops = _xformers_ops
|
||||||
sys.modules["xformers"] = _xformers
|
sys.modules["xformers"] = _xformers
|
||||||
sys.modules["xformers.ops"] = _xformers_ops
|
sys.modules["xformers.ops"] = _xformers_ops
|
||||||
|
|
||||||
|
print(f"[STAR] Available attention backends: {list(_ATTN_BACKENDS.keys())}")
|
||||||
|
|
||||||
# ── Standard imports ────────────────────────────────────────────────────
|
# ── Standard imports ────────────────────────────────────────────────────
|
||||||
import argparse # noqa: E402
|
import argparse # noqa: E402
|
||||||
import json # noqa: E402
|
import json # noqa: E402
|
||||||
@@ -159,6 +197,9 @@ def parse_args():
|
|||||||
help="Post-processing color correction")
|
help="Post-processing color correction")
|
||||||
g.add_argument("--prompt", default="",
|
g.add_argument("--prompt", default="",
|
||||||
help="Text prompt (empty = STAR built-in quality prompt)")
|
help="Text prompt (empty = STAR built-in quality prompt)")
|
||||||
|
g.add_argument("--attention", default="sdpa",
|
||||||
|
choices=list(_ATTN_BACKENDS.keys()),
|
||||||
|
help="Attention backend")
|
||||||
|
|
||||||
# -- Video output --
|
# -- Video output --
|
||||||
g = p.add_argument_group("video output")
|
g = p.add_argument_group("video output")
|
||||||
@@ -583,6 +624,8 @@ def main():
|
|||||||
print(f"[STAR] Model: {model_path}")
|
print(f"[STAR] Model: {model_path}")
|
||||||
star_model = load_model(model_path, args.precision, args.offload, device)
|
star_model = load_model(model_path, args.precision, args.offload, device)
|
||||||
|
|
||||||
|
_set_attn(args.attention)
|
||||||
|
|
||||||
# Create writer and process
|
# Create writer and process
|
||||||
writer = make_writer(output_path, fps, w_out, h_out, args, args.input, is_single)
|
writer = make_writer(output_path, fps, w_out, h_out, args, args.input, is_single)
|
||||||
process_and_stream(star_model, input_frames, writer, args)
|
process_and_stream(star_model, input_frames, writer, args)
|
||||||
|
|||||||
91
nodes.py
91
nodes.py
@@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import types
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import comfy.model_management as mm
|
import comfy.model_management as mm
|
||||||
|
|
||||||
@@ -27,42 +29,78 @@ if not os.path.isdir(os.path.join(STAR_REPO, "video_to_video")):
|
|||||||
if STAR_REPO not in sys.path:
|
if STAR_REPO not in sys.path:
|
||||||
sys.path.insert(0, STAR_REPO)
|
sys.path.insert(0, STAR_REPO)
|
||||||
|
|
||||||
# Provide an xformers compatibility shim if xformers is not installed.
|
# ── Attention backend dispatcher ──────────────────────────────────────
|
||||||
# The STAR UNet only uses xformers.ops.memory_efficient_attention.
|
# Build a registry of available backends at import time.
|
||||||
# Priority: SageAttention (fastest, INT8 quantized) > PyTorch native SDPA (always available).
|
# sdpa (PyTorch native) is always available and is the default.
|
||||||
if "xformers" not in sys.modules:
|
_ATTN_BACKENDS = {"sdpa": None} # None = use F.scaled_dot_product_attention directly
|
||||||
|
|
||||||
|
# Try real xformers — guard against capturing our own dispatcher on reload
|
||||||
|
# or another node's shim by checking for a marker attribute.
|
||||||
|
_real_xformers_mea = None
|
||||||
try:
|
try:
|
||||||
import xformers # noqa: F401
|
import xformers.ops
|
||||||
|
_candidate = xformers.ops.memory_efficient_attention
|
||||||
|
if not getattr(_candidate, "_is_star_dispatcher", False):
|
||||||
|
_real_xformers_mea = _candidate
|
||||||
|
_ATTN_BACKENDS["xformers"] = _real_xformers_mea
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import types
|
pass
|
||||||
|
|
||||||
_xformers = types.ModuleType("xformers")
|
|
||||||
_xformers_ops = types.ModuleType("xformers.ops")
|
|
||||||
|
|
||||||
|
# Try SageAttention variants
|
||||||
|
_SAGE_VARIANTS = [
|
||||||
|
"sageattn",
|
||||||
|
"sageattn_qk_int8_pv_fp16_triton",
|
||||||
|
"sageattn_qk_int8_pv_fp16_cuda",
|
||||||
|
"sageattn_qk_int8_pv_fp8_cuda",
|
||||||
|
]
|
||||||
|
for _name in _SAGE_VARIANTS:
|
||||||
try:
|
try:
|
||||||
from sageattention import sageattn as _sageattn
|
_fn = getattr(__import__("sageattention", fromlist=[_name]), _name)
|
||||||
|
_ATTN_BACKENDS[_name] = _fn
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
_active_attn = "sdpa"
|
||||||
# STAR UNet passes 3D (B*heads, seq, dim); SageAttention needs 4D.
|
|
||||||
return _sageattn(
|
|
||||||
|
def _set_attn(backend: str):
|
||||||
|
global _active_attn
|
||||||
|
if backend not in _ATTN_BACKENDS:
|
||||||
|
print(f"[STAR] Warning: backend '{backend}' not available, falling back to sdpa")
|
||||||
|
backend = "sdpa"
|
||||||
|
_active_attn = backend
|
||||||
|
print(f"[STAR] Attention backend: {backend}")
|
||||||
|
|
||||||
|
|
||||||
|
def _dispatched_mea(q, k, v, attn_bias=None, op=None):
|
||||||
|
if _active_attn == "sdpa":
|
||||||
|
return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
|
||||||
|
if _active_attn == "xformers":
|
||||||
|
return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
|
||||||
|
# SageAttention variants: need 4D tensors (batch, heads, seq, dim)
|
||||||
|
fn = _ATTN_BACKENDS[_active_attn]
|
||||||
|
return fn(
|
||||||
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
|
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
|
||||||
tensor_layout="HND", is_causal=False,
|
tensor_layout="HND", is_causal=False,
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
|
|
||||||
print("[STAR] xformers not found — using SageAttention (fast INT8 quantized).")
|
|
||||||
except ImportError:
|
|
||||||
def _memory_efficient_attention(q, k, v, attn_bias=None, op=None):
|
|
||||||
return torch.nn.functional.scaled_dot_product_attention(
|
|
||||||
q, k, v, attn_mask=attn_bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("[STAR] xformers not found — using PyTorch native SDPA as fallback.")
|
_dispatched_mea._is_star_dispatcher = True
|
||||||
|
|
||||||
_xformers_ops.memory_efficient_attention = _memory_efficient_attention
|
# Always install the dispatcher as xformers.ops.memory_efficient_attention
|
||||||
|
# so the dropdown controls what actually runs regardless of real xformers.
|
||||||
|
if "xformers" in sys.modules:
|
||||||
|
sys.modules["xformers"].ops.memory_efficient_attention = _dispatched_mea
|
||||||
|
else:
|
||||||
|
_xformers = types.ModuleType("xformers")
|
||||||
|
_xformers_ops = types.ModuleType("xformers.ops")
|
||||||
|
_xformers_ops.memory_efficient_attention = _dispatched_mea
|
||||||
_xformers.ops = _xformers_ops
|
_xformers.ops = _xformers_ops
|
||||||
sys.modules["xformers"] = _xformers
|
sys.modules["xformers"] = _xformers
|
||||||
sys.modules["xformers.ops"] = _xformers_ops
|
sys.modules["xformers.ops"] = _xformers_ops
|
||||||
|
|
||||||
|
print(f"[STAR] Available attention backends: {list(_ATTN_BACKENDS.keys())}")
|
||||||
|
|
||||||
# Known models on HuggingFace that can be auto-downloaded.
|
# Known models on HuggingFace that can be auto-downloaded.
|
||||||
HF_REPO = "SherryX/STAR"
|
HF_REPO = "SherryX/STAR"
|
||||||
HF_MODELS = {
|
HF_MODELS = {
|
||||||
@@ -121,6 +159,12 @@ class STARModelLoader:
|
|||||||
"default": "disabled",
|
"default": "disabled",
|
||||||
"tooltip": "disabled: all on GPU (~39GB). model: swap UNet/VAE/CLIP to CPU when idle (~16GB). aggressive: model offload + single-frame VAE decode (~12GB).",
|
"tooltip": "disabled: all on GPU (~39GB). model: swap UNet/VAE/CLIP to CPU when idle (~16GB). aggressive: model offload + single-frame VAE decode (~12GB).",
|
||||||
}),
|
}),
|
||||||
|
"attention": (list(_ATTN_BACKENDS.keys()), {
|
||||||
|
"default": "sdpa",
|
||||||
|
"tooltip": "Attention backend. sdpa: PyTorch native (default, always available). "
|
||||||
|
"xformers: original backend. sageattn: SageAttention auto-select. "
|
||||||
|
"Other sageattn_* variants: specific SageAttention kernels for fine-tuning speed/precision.",
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,7 +174,7 @@ class STARModelLoader:
|
|||||||
CATEGORY = "STAR"
|
CATEGORY = "STAR"
|
||||||
DESCRIPTION = "Loads the STAR video super-resolution model (UNet+ControlNet, OpenCLIP text encoder, temporal VAE). All components are auto-downloaded on first use."
|
DESCRIPTION = "Loads the STAR video super-resolution model (UNet+ControlNet, OpenCLIP text encoder, temporal VAE). All components are auto-downloaded on first use."
|
||||||
|
|
||||||
def load_model(self, model_name, precision, offload="disabled"):
|
def load_model(self, model_name, precision, offload="disabled", attention="sdpa"):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
|
dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
|
||||||
dtype = dtype_map[precision]
|
dtype = dtype_map[precision]
|
||||||
@@ -204,6 +248,7 @@ class STARModelLoader:
|
|||||||
"device": device,
|
"device": device,
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
"offload": offload,
|
"offload": offload,
|
||||||
|
"attention": attention,
|
||||||
}
|
}
|
||||||
return (star_model,)
|
return (star_model,)
|
||||||
|
|
||||||
@@ -278,6 +323,8 @@ class STARVideoSuperResolution:
|
|||||||
color_fix,
|
color_fix,
|
||||||
segment_size=0,
|
segment_size=0,
|
||||||
):
|
):
|
||||||
|
_set_attn(star_model.get("attention", "sdpa"))
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
star_model=star_model,
|
star_model=star_model,
|
||||||
images=images,
|
images=images,
|
||||||
|
|||||||
Reference in New Issue
Block a user