Speed: auto flash-attention/SDPA + document perf levers
transformers .generate() is the slow path; reasoning token volume and swap_eval (2 passes) are the multipliers. Now requests attn_implementation flash_attention_2 -> sdpa -> default automatically (free speedup, flash-attn optional). README gains a Performance section: swap_eval off (biggest free win), flash-attn, smaller model/ fewer axes, avoid nf4 for speed, and vLLM/SGLang as the real production-speed path. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+15
-6
@@ -290,14 +290,23 @@ def _load_model(model_path: str, precision: str):
|
||||
else:
|
||||
load_kwargs["dtype"] = torch.bfloat16 if precision == "bf16" else torch.float16
|
||||
|
||||
# Faster attention: flash_attention_2 (needs flash-attn) -> sdpa (built-in) -> default.
|
||||
model, last_err = None, None
|
||||
for cls in candidates:
|
||||
try:
|
||||
model = cls.from_pretrained(model_path, **load_kwargs)
|
||||
for attn in ("flash_attention_2", "sdpa", None):
|
||||
kw = dict(load_kwargs)
|
||||
if attn:
|
||||
kw["attn_implementation"] = attn
|
||||
for cls in candidates:
|
||||
try:
|
||||
model = cls.from_pretrained(model_path, **kw)
|
||||
break
|
||||
except Exception as e: # wrong class OR attn impl unavailable -> try next
|
||||
last_err = e
|
||||
model = None
|
||||
if model is not None:
|
||||
if attn:
|
||||
print(f"[QwenVLImageJudge] attention: {attn}")
|
||||
break
|
||||
except Exception as e: # arch not in this auto class's registry -> try the next
|
||||
last_err = e
|
||||
model = None
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
f"[QwenVLImageJudge] could not load {model_path} with any of "
|
||||
|
||||
Reference in New Issue
Block a user