diff --git a/gimm_vfi_arch/generalizable_INR/modules/softsplat.py b/gimm_vfi_arch/generalizable_INR/modules/softsplat.py index f8139d3..60eaec1 100644 --- a/gimm_vfi_arch/generalizable_INR/modules/softsplat.py +++ b/gimm_vfi_arch/generalizable_INR/modules/softsplat.py @@ -268,11 +268,14 @@ _cuda_launch_cache = {} @torch.compiler.disable() def cuda_launch(strKey: str): if strKey not in _cuda_launch_cache: - try: - os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path()) - except Exception: - if "CUDA_HOME" not in os.environ: - raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.") + if "CUDA_HOME" not in os.environ: + try: + cuda_path = cupy.cuda.get_cuda_path() + except Exception: + cuda_path = None + if cuda_path is None: + cuda_path = "/usr/local/cuda" + os.environ["CUDA_HOME"] = cuda_path strKernel = objCudacache[strKey]["strKernel"] strFunction = objCudacache[strKey]["strFunction"] _cuda_launch_cache[strKey] = cupy.RawModule( diff --git a/sgm_vfi_arch/softsplat.py b/sgm_vfi_arch/softsplat.py index f01f5b0..ad2d293 100644 --- a/sgm_vfi_arch/softsplat.py +++ b/sgm_vfi_arch/softsplat.py @@ -224,7 +224,13 @@ _cuda_launch_cache = {} def cuda_launch(strKey:str): if strKey not in _cuda_launch_cache: if 'CUDA_HOME' not in os.environ: - os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path() + try: + cuda_path = cupy.cuda.get_cuda_path() + except Exception: + cuda_path = None + if cuda_path is None: + cuda_path = '/usr/local/cuda' + os.environ['CUDA_HOME'] = cuda_path _cuda_launch_cache[strKey] = cupy.RawKernel( objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'],