Fix open_clip batch_first compatibility via auto-applied patch
Newer open_clip creates nn.MultiheadAttention with batch_first=True, but STAR's embedder unconditionally permutes its input to [seq, batch, embed]. This causes a RuntimeError in the text encoder (attn_mask shape mismatch). The patch detects batch_first at runtime and only permutes when needed. Patches in patches/ are auto-applied to the STAR submodule on startup and are skipped gracefully if already applied. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
21
inference.py
21
inference.py
@@ -29,8 +29,26 @@ import types
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||||
|
STAR_REPO = SCRIPT_DIR / "STAR"
|
||||||
sys.path.insert(0, str(SCRIPT_DIR))
|
sys.path.insert(0, str(SCRIPT_DIR))
|
||||||
sys.path.insert(0, str(SCRIPT_DIR / "STAR"))
|
sys.path.insert(0, str(STAR_REPO))
|
||||||
|
|
||||||
|
# Apply patches from patches/ directory to the STAR submodule.
|
||||||
|
import subprocess # noqa: E402
|
||||||
|
|
||||||
|
_PATCHES_DIR = SCRIPT_DIR / "patches"
|
||||||
|
if _PATCHES_DIR.is_dir():
|
||||||
|
for _patch in sorted(_PATCHES_DIR.iterdir()):
|
||||||
|
if _patch.suffix != ".patch":
|
||||||
|
continue
|
||||||
|
if subprocess.call(
|
||||||
|
["git", "apply", "--check", "--reverse", str(_patch)],
|
||||||
|
cwd=str(STAR_REPO), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
|
||||||
|
) != 0:
|
||||||
|
if subprocess.call(["git", "apply", str(_patch)], cwd=str(STAR_REPO)) == 0:
|
||||||
|
print(f"[STAR] Applied patch: {_patch.name}")
|
||||||
|
else:
|
||||||
|
print(f"[STAR] Warning: failed to apply patch: {_patch.name}")
|
||||||
|
|
||||||
import torch # noqa: E402 — needed for stub defaults
|
import torch # noqa: E402 — needed for stub defaults
|
||||||
|
|
||||||
@@ -138,7 +156,6 @@ print(f"[STAR] Available attention backends: {list(_ATTN_BACKENDS.keys())}")
|
|||||||
import argparse # noqa: E402
|
import argparse # noqa: E402
|
||||||
import json # noqa: E402
|
import json # noqa: E402
|
||||||
import shutil # noqa: E402
|
import shutil # noqa: E402
|
||||||
import subprocess # noqa: E402
|
|
||||||
|
|
||||||
import numpy as np # noqa: E402
|
import numpy as np # noqa: E402
|
||||||
from PIL import Image # noqa: E402
|
from PIL import Image # noqa: E402
|
||||||
|
|||||||
18
nodes.py
18
nodes.py
@@ -29,6 +29,24 @@ if not os.path.isdir(os.path.join(STAR_REPO, "video_to_video")):
|
|||||||
if STAR_REPO not in sys.path:
|
if STAR_REPO not in sys.path:
|
||||||
sys.path.insert(0, STAR_REPO)
|
sys.path.insert(0, STAR_REPO)
|
||||||
|
|
||||||
|
# Apply patches from patches/ directory to the STAR submodule.
|
||||||
|
_PATCHES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "patches")
|
||||||
|
if os.path.isdir(_PATCHES_DIR):
|
||||||
|
import subprocess as _sp
|
||||||
|
for _patch in sorted(os.listdir(_PATCHES_DIR)):
|
||||||
|
if not _patch.endswith(".patch"):
|
||||||
|
continue
|
||||||
|
_patch_path = os.path.join(_PATCHES_DIR, _patch)
|
||||||
|
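# `git apply --check --reverse` exits 0 only when the patch is already applied (it could be cleanly reversed), so a non-zero status means it still needs to be applied.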
|
||||||
|
if _sp.call(
|
||||||
|
["git", "apply", "--check", "--reverse", _patch_path],
|
||||||
|
cwd=STAR_REPO, stdout=_sp.DEVNULL, stderr=_sp.DEVNULL,
|
||||||
|
) != 0:
|
||||||
|
if _sp.call(["git", "apply", _patch_path], cwd=STAR_REPO) == 0:
|
||||||
|
print(f"[STAR] Applied patch: {_patch}")
|
||||||
|
else:
|
||||||
|
print(f"[STAR] Warning: failed to apply patch: {_patch}")
|
||||||
|
|
||||||
# ── Attention backend dispatcher ──────────────────────────────────────
|
# ── Attention backend dispatcher ──────────────────────────────────────
|
||||||
# Build a registry of available backends at import time.
|
# Build a registry of available backends at import time.
|
||||||
# sdpa (PyTorch native) is always available and is the default.
|
# sdpa (PyTorch native) is always available and is the default.
|
||||||
|
|||||||
24
patches/openclip_batch_first.patch
Normal file
24
patches/openclip_batch_first.patch
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
diff --git a/video_to_video/modules/embedder.py b/video_to_video/modules/embedder.py
|
||||||
|
index 9b2e760..29cc0fd 100644
|
||||||
|
--- a/video_to_video/modules/embedder.py
|
||||||
|
+++ b/video_to_video/modules/embedder.py
|
||||||
|
@@ -54,9 +54,17 @@ class FrozenOpenCLIPEmbedder(nn.Module):
|
||||||
|
def encode_with_transformer(self, text):
|
||||||
|
x = self.model.token_embedding(text)
|
||||||
|
x = x + self.model.positional_embedding
|
||||||
|
- x = x.permute(1, 0, 2)
|
||||||
|
+ # Newer open_clip sets batch_first=True on MHA, so the resblocks
|
||||||
|
+ # expect [batch, seq, embed]. Older versions use batch_first=False
|
||||||
|
+ # and expect [seq, batch, embed]. Only permute for the old layout.
|
||||||
|
+ needs_permute = not getattr(
|
||||||
|
+ self.model.transformer.resblocks[0].attn, "batch_first", False
|
||||||
|
+ )
|
||||||
|
+ if needs_permute:
|
||||||
|
+ x = x.permute(1, 0, 2)
|
||||||
|
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||||
|
- x = x.permute(1, 0, 2)
|
||||||
|
+ if needs_permute:
|
||||||
|
+ x = x.permute(1, 0, 2)
|
||||||
|
x = self.model.ln_final(x)
|
||||||
|
return x
|
||||||
|
|
||||||
Reference in New Issue
Block a user