Add pure-PyTorch fallback for SGM-VFI softsplat forward warp

Make cupy import optional so the module loads without cupy installed.
Replace @cupy.memoize decorator with a simple dict cache to avoid
crash at import time. Add _pytorch_softsplat() using scatter_add_
as a fallback when cupy is unavailable or tensors are on CPU.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-11 01:59:23 +02:00
parent 91947c0b8c
commit 8d8407ec9d
+62 -9
View File
@@ -1,7 +1,10 @@
#!/usr/bin/env python
import collections
import cupy
try:
import cupy
except ImportError:
cupy = None
import os
import re
import torch
@@ -216,20 +219,70 @@ def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
# end
@cupy.memoize(for_each_device=True)
_cuda_launch_cache = {}
def cuda_launch(strKey:str):
if strKey not in _cuda_launch_cache:
if 'CUDA_HOME' not in os.environ:
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
# end
return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'],
options=tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include']))
_cuda_launch_cache[strKey] = cupy.RawKernel(
objCudacache[strKey]['strKernel'],
objCudacache[strKey]['strFunction'],
options=tuple(['-I ' + os.environ['CUDA_HOME'],
'-I ' + os.environ['CUDA_HOME'] + '/include'])
)
return _cuda_launch_cache[strKey]
# end
##########################################################
def _pytorch_softsplat(tenIn, tenFlow):
"""Pure-PyTorch forward warp via bilinear splatting (scatter_add)."""
B, C, H, W = tenIn.shape
tenOut = tenIn.new_zeros(B, C, H, W)
grid_y, grid_x = torch.meshgrid(
torch.arange(H, device=tenIn.device, dtype=tenIn.dtype),
torch.arange(W, device=tenIn.device, dtype=tenIn.dtype),
indexing='ij',
)
flt_x = grid_x.unsqueeze(0) + tenFlow[:, 0, :, :]
flt_y = grid_y.unsqueeze(0) + tenFlow[:, 1, :, :]
valid = torch.isfinite(flt_x) & torch.isfinite(flt_y)
flt_x = torch.where(valid, flt_x, torch.zeros_like(flt_x))
flt_y = torch.where(valid, flt_y, torch.zeros_like(flt_y))
nw_x = flt_x.floor().long()
nw_y = flt_y.floor().long()
frac_x = flt_x - nw_x.float()
frac_y = flt_y - nw_y.float()
w_nw = (1.0 - frac_x) * (1.0 - frac_y) * valid
w_ne = frac_x * (1.0 - frac_y) * valid
w_sw = (1.0 - frac_x) * frac_y * valid
w_se = frac_x * frac_y * valid
out_flat = tenOut.view(B, C, -1)
in_flat = tenIn
for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]:
tx = nw_x + dx
ty = nw_y + dy
in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H)
w_masked = w * in_bounds
idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1))
idx = idx.unsqueeze(1).expand_as(in_flat)
weighted = in_flat * w_masked.unsqueeze(1)
out_flat.scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1))
return tenOut
# end
def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str):
assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
@@ -281,7 +334,7 @@ class softsplat_func(torch.autograd.Function):
def forward(self, tenIn, tenFlow):
tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
if tenIn.is_cuda == True:
if tenIn.is_cuda and cupy is not None:
cuda_launch(cuda_kernel('softsplat_out', '''
extern "C" __global__ void __launch_bounds__(512) softsplat_out(
const int n,
@@ -345,8 +398,8 @@ class softsplat_func(torch.autograd.Function):
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
)
elif tenIn.is_cuda != True:
assert(False)
else:
tenOut = _pytorch_softsplat(tenIn, tenFlow)
# end