diff --git a/gimm_vfi_arch/generalizable_INR/modules/softsplat.py b/gimm_vfi_arch/generalizable_INR/modules/softsplat.py index 415fc51..f8139d3 100644 --- a/gimm_vfi_arch/generalizable_INR/modules/softsplat.py +++ b/gimm_vfi_arch/generalizable_INR/modules/softsplat.py @@ -9,7 +9,10 @@ # -------------------------------------------------------- import collections -import cupy +try: + import cupy +except ImportError: + cupy = None import os import re import torch @@ -260,31 +263,75 @@ def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict): # end -@cupy.memoize(for_each_device=True) +_cuda_launch_cache = {} + @torch.compiler.disable() def cuda_launch(strKey: str): - try: - os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path()) - except Exception: - if "CUDA_HOME" not in os.environ: - raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.") - - strKernel = objCudacache[strKey]["strKernel"] - strFunction = objCudacache[strKey]["strFunction"] - - return cupy.RawModule( - code=strKernel, - options=( - "-I " + os.environ["CUDA_HOME"], - "-I " + os.environ["CUDA_HOME"] + "/include", - ), - ).get_function(strFunction) + if strKey not in _cuda_launch_cache: + try: + os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path()) + except Exception: + if "CUDA_HOME" not in os.environ: + raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.") + strKernel = objCudacache[strKey]["strKernel"] + strFunction = objCudacache[strKey]["strFunction"] + _cuda_launch_cache[strKey] = cupy.RawModule( + code=strKernel, + options=( + "-I " + os.environ["CUDA_HOME"], + "-I " + os.environ["CUDA_HOME"] + "/include", + ), + ).get_function(strFunction) + return _cuda_launch_cache[strKey] ########################################################## +def _pytorch_softsplat(tenIn, tenFlow): + """Pure-PyTorch forward warp via bilinear splatting (scatter_add).""" + B, C, H, W = tenIn.shape + tenOut = tenIn.new_zeros(B, C, H, W) + + grid_y, grid_x = torch.meshgrid( + torch.arange(H, device=tenIn.device, dtype=tenIn.dtype), + torch.arange(W, device=tenIn.device, dtype=tenIn.dtype), + indexing='ij', + ) + + flt_x = grid_x.unsqueeze(0) + tenFlow[:, 0, :, :] + flt_y = grid_y.unsqueeze(0) + tenFlow[:, 1, :, :] + + valid = torch.isfinite(flt_x) & torch.isfinite(flt_y) + flt_x = torch.where(valid, flt_x, torch.zeros_like(flt_x)) + flt_y = torch.where(valid, flt_y, torch.zeros_like(flt_y)) + + nw_x = flt_x.floor().long() + nw_y = flt_y.floor().long() + frac_x = flt_x - nw_x.to(flt_x.dtype) + frac_y = flt_y - nw_y.to(flt_y.dtype) + + w_nw = (1.0 - frac_x) * (1.0 - frac_y) * valid + w_ne = frac_x * (1.0 - frac_y) * valid + w_sw = (1.0 - frac_x) * frac_y * valid + w_se = frac_x * frac_y * valid + + out_flat = tenOut.view(B, C, -1) + + for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]: + tx = nw_x + dx + ty = nw_y + dy + in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H) + w_masked = w * in_bounds + idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1)) + idx = idx.unsqueeze(1).expand_as(tenIn) + weighted = tenIn * w_masked.unsqueeze(1) + out_flat.scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1)) + + return tenOut + + @torch.compiler.disable() def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False): assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"] @@ -366,7 +413,7 @@ class softsplat_func(torch.autograd.Function): [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]] ) - if tenIn.is_cuda == True: + if tenIn.is_cuda and cupy is not None: cuda_launch( cuda_kernel( "softsplat_out", @@ -439,8 +486,8 @@ class softsplat_func(torch.autograd.Function): ), ) - elif tenIn.is_cuda != True: - assert False + else: + tenOut = _pytorch_softsplat(tenIn, tenFlow) # end