diff --git a/bim_vfi_arch/costvol.py b/bim_vfi_arch/costvol.py
index e094dc3..931c6d0 100644
--- a/bim_vfi_arch/costvol.py
+++ b/bim_vfi_arch/costvol.py
@@ -1,12 +1,26 @@
 #!/usr/bin/env python
 
 import collections
-import cupy
 import os
 import re
 import torch
 import typing
 
+cupy = None
+
+def _ensure_cupy():
+    global cupy
+    if cupy is None:
+        try:
+            import cupy as _cupy
+            cupy = _cupy
+        except ImportError:
+            raise RuntimeError(
+                "cupy is required for BIM-VFI. Install it with:\n"
+                " pip install cupy-cuda12x (or cupy-cuda11x for CUDA 11)\n"
+                "Or run install.py from the Comfyui-BIM-VFI directory."
+            )
+
 
 ##########################################################
 
@@ -15,16 +29,19 @@ objCudacache = {}
 
 
 def cuda_int32(intIn:int):
+    _ensure_cupy()
     return cupy.int32(intIn)
 # end
 
 
 def cuda_float32(fltIn:float):
+    _ensure_cupy()
     return cupy.float32(fltIn)
 # end
 
 
 def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
+    _ensure_cupy()
     if 'device' not in objCudacache:
         objCudacache['device'] = torch.cuda.get_device_name()
     # end
@@ -216,13 +233,16 @@ def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
 # end
 
 
-@cupy.memoize(for_each_device=True)
-def cuda_launch(strKey:str):
-    if 'CUDA_HOME' not in os.environ:
-        os.environ['CUDA_HOME'] = '/usr/local/cuda/'
-    # end
+_cuda_launch_cache = {}
 
-    return cupy.RawModule(code=objCudacache[strKey]['strKernel']).get_function(objCudacache[strKey]['strFunction'])
+def cuda_launch(strKey:str):
+    _ensure_cupy()
+    if strKey not in _cuda_launch_cache:
+        if 'CUDA_HOME' not in os.environ:
+            os.environ['CUDA_HOME'] = '/usr/local/cuda/'
+        # end
+        _cuda_launch_cache[strKey] = cupy.RawModule(code=objCudacache[strKey]['strKernel']).get_function(objCudacache[strKey]['strFunction'])
+    return _cuda_launch_cache[strKey]
 # end
 
 