Write detected cupy variant into requirements.txt at install time

install.py now updates requirements.txt with the correct cupy-cuda
package matching PyTorch's CUDA version before running pip install.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-12 18:42:23 +01:00
parent 471a299027
commit d44885252f

View File

@@ -14,8 +14,7 @@ def get_cupy_package():
if cuda_version is None: if cuda_version is None:
print("[BIM-VFI] WARNING: PyTorch has no CUDA version info.") print("[BIM-VFI] WARNING: PyTorch has no CUDA version info.")
return None return None
major = cuda_version.split(".")[0] major = int(cuda_version.split(".")[0])
major = int(major)
cupy_pkg = f"cupy-cuda{major}x" cupy_pkg = f"cupy-cuda{major}x"
print(f"[BIM-VFI] Detected CUDA {cuda_version}, will use {cupy_pkg}") print(f"[BIM-VFI] Detected CUDA {cuda_version}, will use {cupy_pkg}")
return cupy_pkg return cupy_pkg
@@ -24,24 +23,29 @@ def get_cupy_package():
return None return None
def update_requirements(cupy_pkg):
"""Write the correct cupy package into requirements.txt."""
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
lines = []
if os.path.exists(requirements_path):
with open(requirements_path, "r") as f:
lines = [l.rstrip() for l in f if not l.strip().startswith("cupy")]
if cupy_pkg and cupy_pkg not in lines:
lines.append(cupy_pkg)
with open(requirements_path, "w") as f:
f.write("\n".join(lines) + "\n")
def install(): def install():
cupy_pkg = get_cupy_package()
if cupy_pkg:
update_requirements(cupy_pkg)
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt") requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
subprocess.check_call([ subprocess.check_call([
sys.executable, "-m", "pip", "install", "-r", requirements_path sys.executable, "-m", "pip", "install", "-r", requirements_path
]) ])
# Install cupy matching the current CUDA version
try:
import cupy
print("[BIM-VFI] cupy already installed, skipping.")
except ImportError:
cupy_pkg = get_cupy_package()
if cupy_pkg:
print(f"[BIM-VFI] Installing {cupy_pkg} to match PyTorch CUDA...")
subprocess.check_call([
sys.executable, "-m", "pip", "install", cupy_pkg
])
if __name__ == "__main__": if __name__ == "__main__":
install() install()