diff --git a/inference.py b/inference.py
index afa2d38..a65321a 100755
--- a/inference.py
+++ b/inference.py
@@ -29,8 +29,33 @@ import types
 from pathlib import Path
 
 SCRIPT_DIR = Path(__file__).resolve().parent
+STAR_REPO = SCRIPT_DIR / "STAR"
 sys.path.insert(0, str(SCRIPT_DIR))
-sys.path.insert(0, str(SCRIPT_DIR / "STAR"))
+sys.path.insert(0, str(STAR_REPO))
+
+# Apply patches from patches/ directory to the STAR submodule.
+import subprocess  # noqa: E402
+
+_PATCHES_DIR = SCRIPT_DIR / "patches"
+if _PATCHES_DIR.is_dir():
+    for _patch in sorted(_PATCHES_DIR.iterdir()):
+        if _patch.suffix != ".patch":
+            continue
+        try:
+            # --check + --reverse: succeeds silently if already applied.
+            if subprocess.call(
+                ["git", "apply", "--check", "--reverse", str(_patch)],
+                cwd=str(STAR_REPO), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
+            ) != 0:
+                if subprocess.call(["git", "apply", str(_patch)], cwd=str(STAR_REPO)) == 0:
+                    print(f"[STAR] Applied patch: {_patch.name}")
+                else:
+                    print(f"[STAR] Warning: failed to apply patch: {_patch.name}")
+        except OSError as _exc:
+            # subprocess.call raises OSError when git is not on PATH or the
+            # cwd is missing; patching is best-effort, so warn and move on.
+            print(f"[STAR] Warning: could not run git to apply patches: {_exc}")
+            break
 
 import torch  # noqa: E402 — needed for stub defaults
 
@@ -138,7 +163,6 @@ print(f"[STAR] Available attention backends: {list(_ATTN_BACKENDS.keys())}")
 import argparse  # noqa: E402
 import json  # noqa: E402
 import shutil  # noqa: E402
-import subprocess  # noqa: E402
 
 import numpy as np  # noqa: E402
 from PIL import Image  # noqa: E402
diff --git a/nodes.py b/nodes.py
index 1bf579e..ba4ce0e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -29,6 +29,30 @@ if not os.path.isdir(os.path.join(STAR_REPO, "video_to_video")):
 if STAR_REPO not in sys.path:
     sys.path.insert(0, STAR_REPO)
 
+# Apply patches from patches/ directory to the STAR submodule.
+_PATCHES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "patches")
+if os.path.isdir(_PATCHES_DIR):
+    import subprocess as _sp
+    for _patch in sorted(os.listdir(_PATCHES_DIR)):
+        if not _patch.endswith(".patch"):
+            continue
+        _patch_path = os.path.join(_PATCHES_DIR, _patch)
+        try:
+            # --check + --reverse: succeeds silently if already applied.
+            if _sp.call(
+                ["git", "apply", "--check", "--reverse", _patch_path],
+                cwd=STAR_REPO, stdout=_sp.DEVNULL, stderr=_sp.DEVNULL,
+            ) != 0:
+                if _sp.call(["git", "apply", _patch_path], cwd=STAR_REPO) == 0:
+                    print(f"[STAR] Applied patch: {_patch}")
+                else:
+                    print(f"[STAR] Warning: failed to apply patch: {_patch}")
+        except OSError as _exc:
+            # subprocess.call raises OSError when git is not on PATH or the
+            # cwd is missing; patching is best-effort, so warn and move on.
+            print(f"[STAR] Warning: could not run git to apply patches: {_exc}")
+            break
+
 # ── Attention backend dispatcher ──────────────────────────────────────
 # Build a registry of available backends at import time.
 # sdpa (PyTorch native) is always available and is the default.
diff --git a/patches/openclip_batch_first.patch b/patches/openclip_batch_first.patch
new file mode 100644
index 0000000..25ce402
--- /dev/null
+++ b/patches/openclip_batch_first.patch
@@ -0,0 +1,24 @@
+diff --git a/video_to_video/modules/embedder.py b/video_to_video/modules/embedder.py
+index 9b2e760..29cc0fd 100644
+--- a/video_to_video/modules/embedder.py
++++ b/video_to_video/modules/embedder.py
+@@ -54,9 +54,17 @@ class FrozenOpenCLIPEmbedder(nn.Module):
+     def encode_with_transformer(self, text):
+         x = self.model.token_embedding(text)
+         x = x + self.model.positional_embedding
+-        x = x.permute(1, 0, 2)
++        # Newer open_clip sets batch_first=True on MHA, so the resblocks
++        # expect [batch, seq, embed]. Older versions use batch_first=False
++        # and expect [seq, batch, embed]. Only permute for the old layout.
++        needs_permute = not getattr(
++            self.model.transformer.resblocks[0].attn, "batch_first", False
++        )
++        if needs_permute:
++            x = x.permute(1, 0, 2)
+         x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+-        x = x.permute(1, 0, 2)
++        if needs_permute:
++            x = x.permute(1, 0, 2)
+         x = self.model.ln_final(x)
+         return x
+ 