From 8d8407ec9d935aeae59175ede7c8f01409bb6742 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 11 Apr 2026 01:59:23 +0200 Subject: [PATCH] Add pure-PyTorch fallback for SGM-VFI softsplat forward warp Make cupy import optional so the module loads without cupy installed. Replace @cupy.memoize decorator with a simple dict cache to avoid crash at import time. Add _pytorch_softsplat() using scatter_add_ as a fallback when cupy is unavailable or tensors are on CPU. Co-Authored-By: Claude Opus 4.6 --- sgm_vfi_arch/softsplat.py | 75 +++++++++++++++++++++++++++++++++------ 1 file changed, 64 insertions(+), 11 deletions(-) diff --git a/sgm_vfi_arch/softsplat.py b/sgm_vfi_arch/softsplat.py index eeccb88..a34b933 100644 --- a/sgm_vfi_arch/softsplat.py +++ b/sgm_vfi_arch/softsplat.py @@ -1,7 +1,10 @@ #!/usr/bin/env python import collections -import cupy +try: + import cupy +except ImportError: + cupy = None import os import re import torch @@ -216,20 +219,70 @@ def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): # end -@cupy.memoize(for_each_device=True) -def cuda_launch(strKey:str): - if 'CUDA_HOME' not in os.environ: - os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path() - # end +_cuda_launch_cache = {} - return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'], - options=tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])) +def cuda_launch(strKey:str): + if strKey not in _cuda_launch_cache: + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path() + _cuda_launch_cache[strKey] = cupy.RawKernel( + objCudacache[strKey]['strKernel'], + objCudacache[strKey]['strFunction'], + options=tuple(['-I ' + os.environ['CUDA_HOME'], + '-I ' + os.environ['CUDA_HOME'] + '/include']) + ) + return _cuda_launch_cache[strKey] # end ########################################################## +def _pytorch_softsplat(tenIn, tenFlow): + """Pure-PyTorch forward warp via bilinear splatting (scatter_add).""" + B, C, H, W = tenIn.shape + tenOut = tenIn.new_zeros(B, C, H, W) + + grid_y, grid_x = torch.meshgrid( + torch.arange(H, device=tenIn.device, dtype=tenIn.dtype), + torch.arange(W, device=tenIn.device, dtype=tenIn.dtype), + indexing='ij', + ) + + flt_x = grid_x.unsqueeze(0) + tenFlow[:, 0, :, :] + flt_y = grid_y.unsqueeze(0) + tenFlow[:, 1, :, :] + + valid = torch.isfinite(flt_x) & torch.isfinite(flt_y) + flt_x = torch.where(valid, flt_x, torch.zeros_like(flt_x)) + flt_y = torch.where(valid, flt_y, torch.zeros_like(flt_y)) + + nw_x = flt_x.floor().long() + nw_y = flt_y.floor().long() + frac_x = flt_x - nw_x.float() + frac_y = flt_y - nw_y.float() + + w_nw = (1.0 - frac_x) * (1.0 - frac_y) * valid + w_ne = frac_x * (1.0 - frac_y) * valid + w_sw = (1.0 - frac_x) * frac_y * valid + w_se = frac_x * frac_y * valid + + out_flat = tenOut.view(B, C, -1) + in_flat = tenIn + + for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]: + tx = nw_x + dx + ty = nw_y + dy + in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H) + w_masked = w * in_bounds + idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1)) + idx = idx.unsqueeze(1).expand_as(in_flat) + weighted = in_flat * w_masked.unsqueeze(1) + out_flat.scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1)) + + return tenOut +# end + + def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str): assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) @@ -281,7 +334,7 @@ class softsplat_func(torch.autograd.Function): def forward(self, tenIn, tenFlow): tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) - if tenIn.is_cuda == True: + if tenIn.is_cuda and cupy is not None: cuda_launch(cuda_kernel('softsplat_out', ''' extern "C" __global__ void __launch_bounds__(512) softsplat_out( const int n, @@ -345,8 +398,8 @@ class softsplat_func(torch.autograd.Function): stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) ) - elif tenIn.is_cuda != True: - assert(False) + else: + tenOut = _pytorch_softsplat(tenIn, tenFlow) # end