diff --git a/install.py b/install.py index fe28047..5f3510d 100644 --- a/install.py +++ b/install.py @@ -8,44 +8,39 @@ def get_cupy_package(): try: import torch if not torch.cuda.is_available(): - print("[Tween] WARNING: CUDA not available. cupy requires CUDA.") return None cuda_version = torch.version.cuda if cuda_version is None: - print("[Tween] WARNING: PyTorch has no CUDA version info.") return None major = int(cuda_version.split(".")[0]) cupy_pkg = f"cupy-cuda{major}x" - print(f"[Tween] Detected CUDA {cuda_version}, will use {cupy_pkg}") return cupy_pkg - except Exception as e: - print(f"[Tween] WARNING: Could not detect CUDA version: {e}") + except Exception: return None -def update_requirements(cupy_pkg): - """Write the correct cupy package into requirements.txt.""" - requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt") - lines = [] - if os.path.exists(requirements_path): - with open(requirements_path, "r") as f: - lines = [l.rstrip() for l in f if not l.strip().startswith("cupy")] - if cupy_pkg and cupy_pkg not in lines: - lines.append(cupy_pkg) - with open(requirements_path, "w") as f: - f.write("\n".join(lines) + "\n") - - def install(): - cupy_pkg = get_cupy_package() - if cupy_pkg: - update_requirements(cupy_pkg) - + # Install core requirements first requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt") subprocess.check_call([ sys.executable, "-m", "pip", "install", "-r", requirements_path ]) + # Try to install cupy for NVIDIA users (optional, improves performance) + cupy_pkg = get_cupy_package() + if cupy_pkg: + try: + subprocess.check_call([ + sys.executable, "-m", "pip", "install", cupy_pkg + ]) + print(f"[Tween] cupy installed ({cupy_pkg}) — fast CUDA kernels enabled") + except subprocess.CalledProcessError: + print(f"[Tween] WARNING: Could not install {cupy_pkg}. " + f"BIM-VFI, SGM-VFI, and GIMM-VFI will use slower PyTorch fallback.") + else: + print("[Tween] cupy skipped (no NVIDIA CUDA). " + "BIM-VFI, SGM-VFI, and GIMM-VFI will use PyTorch fallback.") + if __name__ == "__main__": install()