Add pure-PyTorch fallback for SGM-VFI softsplat forward warp
Make cupy import optional so the module loads without cupy installed. Replace @cupy.memoize decorator with a simple dict cache to avoid crash at import time. Add _pytorch_softsplat() using scatter_add_ as a fallback when cupy is unavailable or tensors are on CPU. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+64
-11
@@ -1,7 +1,10 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import cupy
|
try:
|
||||||
|
import cupy
|
||||||
|
except ImportError:
|
||||||
|
cupy = None
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
@@ -216,20 +219,70 @@ def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
|
|||||||
# end
|
# end
|
||||||
|
|
||||||
|
|
||||||
@cupy.memoize(for_each_device=True)
|
_cuda_launch_cache = {}
|
||||||
def cuda_launch(strKey:str):
|
|
||||||
if 'CUDA_HOME' not in os.environ:
|
|
||||||
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
|
|
||||||
# end
|
|
||||||
|
|
||||||
return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'],
|
def cuda_launch(strKey:str):
|
||||||
options=tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include']))
|
if strKey not in _cuda_launch_cache:
|
||||||
|
if 'CUDA_HOME' not in os.environ:
|
||||||
|
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
|
||||||
|
_cuda_launch_cache[strKey] = cupy.RawKernel(
|
||||||
|
objCudacache[strKey]['strKernel'],
|
||||||
|
objCudacache[strKey]['strFunction'],
|
||||||
|
options=tuple(['-I ' + os.environ['CUDA_HOME'],
|
||||||
|
'-I ' + os.environ['CUDA_HOME'] + '/include'])
|
||||||
|
)
|
||||||
|
return _cuda_launch_cache[strKey]
|
||||||
# end
|
# end
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
|
def _pytorch_softsplat(tenIn, tenFlow):
|
||||||
|
"""Pure-PyTorch forward warp via bilinear splatting (scatter_add)."""
|
||||||
|
B, C, H, W = tenIn.shape
|
||||||
|
tenOut = tenIn.new_zeros(B, C, H, W)
|
||||||
|
|
||||||
|
grid_y, grid_x = torch.meshgrid(
|
||||||
|
torch.arange(H, device=tenIn.device, dtype=tenIn.dtype),
|
||||||
|
torch.arange(W, device=tenIn.device, dtype=tenIn.dtype),
|
||||||
|
indexing='ij',
|
||||||
|
)
|
||||||
|
|
||||||
|
flt_x = grid_x.unsqueeze(0) + tenFlow[:, 0, :, :]
|
||||||
|
flt_y = grid_y.unsqueeze(0) + tenFlow[:, 1, :, :]
|
||||||
|
|
||||||
|
valid = torch.isfinite(flt_x) & torch.isfinite(flt_y)
|
||||||
|
flt_x = torch.where(valid, flt_x, torch.zeros_like(flt_x))
|
||||||
|
flt_y = torch.where(valid, flt_y, torch.zeros_like(flt_y))
|
||||||
|
|
||||||
|
nw_x = flt_x.floor().long()
|
||||||
|
nw_y = flt_y.floor().long()
|
||||||
|
frac_x = flt_x - nw_x.float()
|
||||||
|
frac_y = flt_y - nw_y.float()
|
||||||
|
|
||||||
|
w_nw = (1.0 - frac_x) * (1.0 - frac_y) * valid
|
||||||
|
w_ne = frac_x * (1.0 - frac_y) * valid
|
||||||
|
w_sw = (1.0 - frac_x) * frac_y * valid
|
||||||
|
w_se = frac_x * frac_y * valid
|
||||||
|
|
||||||
|
out_flat = tenOut.view(B, C, -1)
|
||||||
|
in_flat = tenIn
|
||||||
|
|
||||||
|
for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]:
|
||||||
|
tx = nw_x + dx
|
||||||
|
ty = nw_y + dy
|
||||||
|
in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H)
|
||||||
|
w_masked = w * in_bounds
|
||||||
|
idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1))
|
||||||
|
idx = idx.unsqueeze(1).expand_as(in_flat)
|
||||||
|
weighted = in_flat * w_masked.unsqueeze(1)
|
||||||
|
out_flat.scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1))
|
||||||
|
|
||||||
|
return tenOut
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str):
|
def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str):
|
||||||
assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
|
assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
|
||||||
|
|
||||||
@@ -281,7 +334,7 @@ class softsplat_func(torch.autograd.Function):
|
|||||||
def forward(self, tenIn, tenFlow):
|
def forward(self, tenIn, tenFlow):
|
||||||
tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
|
tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
|
||||||
|
|
||||||
if tenIn.is_cuda == True:
|
if tenIn.is_cuda and cupy is not None:
|
||||||
cuda_launch(cuda_kernel('softsplat_out', '''
|
cuda_launch(cuda_kernel('softsplat_out', '''
|
||||||
extern "C" __global__ void __launch_bounds__(512) softsplat_out(
|
extern "C" __global__ void __launch_bounds__(512) softsplat_out(
|
||||||
const int n,
|
const int n,
|
||||||
@@ -345,8 +398,8 @@ class softsplat_func(torch.autograd.Function):
|
|||||||
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif tenIn.is_cuda != True:
|
else:
|
||||||
assert(False)
|
tenOut = _pytorch_softsplat(tenIn, tenFlow)
|
||||||
|
|
||||||
# end
|
# end
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user