From d44885252fec174e1ef6c295a18c92f2917a9ed6 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 12 Feb 2026 18:42:23 +0100 Subject: [PATCH] Write detected cupy variant into requirements.txt at install time install.py now updates requirements.txt with the correct cupy-cuda package matching PyTorch's CUDA version before running pip install. Co-Authored-By: Claude Opus 4.6 --- install.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/install.py b/install.py index 35175ee..7b478a2 100644 --- a/install.py +++ b/install.py @@ -14,8 +14,7 @@ def get_cupy_package(): if cuda_version is None: print("[BIM-VFI] WARNING: PyTorch has no CUDA version info.") return None - major = cuda_version.split(".")[0] - major = int(major) + major = int(cuda_version.split(".")[0]) cupy_pkg = f"cupy-cuda{major}x" print(f"[BIM-VFI] Detected CUDA {cuda_version}, will use {cupy_pkg}") return cupy_pkg @@ -24,24 +23,29 @@ def get_cupy_package(): return None +def update_requirements(cupy_pkg): + """Write the correct cupy package into requirements.txt.""" + requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt") + lines = [] + if os.path.exists(requirements_path): + with open(requirements_path, "r") as f: + lines = [l.rstrip() for l in f if not l.strip().startswith("cupy")] + if cupy_pkg and cupy_pkg not in lines: + lines.append(cupy_pkg) + with open(requirements_path, "w") as f: + f.write("\n".join(lines) + "\n") + + def install(): + cupy_pkg = get_cupy_package() + if cupy_pkg: + update_requirements(cupy_pkg) + requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt") subprocess.check_call([ sys.executable, "-m", "pip", "install", "-r", requirements_path ]) - # Install cupy matching the current CUDA version - try: - import cupy - print("[BIM-VFI] cupy already installed, skipping.") - except ImportError: - cupy_pkg = get_cupy_package() - if cupy_pkg: - print(f"[BIM-VFI] Installing {cupy_pkg} to match PyTorch CUDA...") - subprocess.check_call([ - sys.executable, "-m", "pip", "install", cupy_pkg - ]) - if __name__ == "__main__": install()