Write detected cupy variant into requirements.txt at install time
install.py now updates requirements.txt with the correct cupy-cuda package matching PyTorch's CUDA version before running pip install. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
32
install.py
32
install.py
@@ -14,8 +14,7 @@ def get_cupy_package():
|
|||||||
if cuda_version is None:
|
if cuda_version is None:
|
||||||
print("[BIM-VFI] WARNING: PyTorch has no CUDA version info.")
|
print("[BIM-VFI] WARNING: PyTorch has no CUDA version info.")
|
||||||
return None
|
return None
|
||||||
major = cuda_version.split(".")[0]
|
major = int(cuda_version.split(".")[0])
|
||||||
major = int(major)
|
|
||||||
cupy_pkg = f"cupy-cuda{major}x"
|
cupy_pkg = f"cupy-cuda{major}x"
|
||||||
print(f"[BIM-VFI] Detected CUDA {cuda_version}, will use {cupy_pkg}")
|
print(f"[BIM-VFI] Detected CUDA {cuda_version}, will use {cupy_pkg}")
|
||||||
return cupy_pkg
|
return cupy_pkg
|
||||||
@@ -24,24 +23,29 @@ def get_cupy_package():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def update_requirements(cupy_pkg):
|
||||||
|
"""Write the correct cupy package into requirements.txt."""
|
||||||
|
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
|
||||||
|
lines = []
|
||||||
|
if os.path.exists(requirements_path):
|
||||||
|
with open(requirements_path, "r") as f:
|
||||||
|
lines = [l.rstrip() for l in f if not l.strip().startswith("cupy")]
|
||||||
|
if cupy_pkg and cupy_pkg not in lines:
|
||||||
|
lines.append(cupy_pkg)
|
||||||
|
with open(requirements_path, "w") as f:
|
||||||
|
f.write("\n".join(lines) + "\n")
|
||||||
|
|
||||||
|
|
||||||
def install():
|
def install():
|
||||||
|
cupy_pkg = get_cupy_package()
|
||||||
|
if cupy_pkg:
|
||||||
|
update_requirements(cupy_pkg)
|
||||||
|
|
||||||
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
|
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
|
||||||
subprocess.check_call([
|
subprocess.check_call([
|
||||||
sys.executable, "-m", "pip", "install", "-r", requirements_path
|
sys.executable, "-m", "pip", "install", "-r", requirements_path
|
||||||
])
|
])
|
||||||
|
|
||||||
# Install cupy matching the current CUDA version
|
|
||||||
try:
|
|
||||||
import cupy
|
|
||||||
print("[BIM-VFI] cupy already installed, skipping.")
|
|
||||||
except ImportError:
|
|
||||||
cupy_pkg = get_cupy_package()
|
|
||||||
if cupy_pkg:
|
|
||||||
print(f"[BIM-VFI] Installing {cupy_pkg} to match PyTorch CUDA...")
|
|
||||||
subprocess.check_call([
|
|
||||||
sys.executable, "-m", "pip", "install", cupy_pkg
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
install()
|
install()
|
||||||
|
|||||||
Reference in New Issue
Block a user