Add pure-PyTorch fallback for GIMM-VFI softsplat forward warp
Make cupy import optional (try/except), replace @cupy.memoize with a dict cache, add _pytorch_softsplat() using scatter_add for bilinear splatting, and update forward() dispatch to fall back to PyTorch when cupy is unavailable or tensor is on CPU. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -9,7 +9,10 @@
|
|||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import cupy
|
try:
|
||||||
|
import cupy
|
||||||
|
except ImportError:
|
||||||
|
cupy = None
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
@@ -260,31 +263,75 @@ def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
|
|||||||
# end
|
# end
|
||||||
|
|
||||||
|
|
||||||
@cupy.memoize(for_each_device=True)
|
_cuda_launch_cache = {}
|
||||||
|
|
||||||
@torch.compiler.disable()
|
@torch.compiler.disable()
|
||||||
def cuda_launch(strKey: str):
|
def cuda_launch(strKey: str):
|
||||||
try:
|
if strKey not in _cuda_launch_cache:
|
||||||
os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path())
|
try:
|
||||||
except Exception:
|
os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path())
|
||||||
if "CUDA_HOME" not in os.environ:
|
except Exception:
|
||||||
raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.")
|
if "CUDA_HOME" not in os.environ:
|
||||||
|
raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.")
|
||||||
strKernel = objCudacache[strKey]["strKernel"]
|
strKernel = objCudacache[strKey]["strKernel"]
|
||||||
strFunction = objCudacache[strKey]["strFunction"]
|
strFunction = objCudacache[strKey]["strFunction"]
|
||||||
|
_cuda_launch_cache[strKey] = cupy.RawModule(
|
||||||
return cupy.RawModule(
|
code=strKernel,
|
||||||
code=strKernel,
|
options=(
|
||||||
options=(
|
"-I " + os.environ["CUDA_HOME"],
|
||||||
"-I " + os.environ["CUDA_HOME"],
|
"-I " + os.environ["CUDA_HOME"] + "/include",
|
||||||
"-I " + os.environ["CUDA_HOME"] + "/include",
|
),
|
||||||
),
|
).get_function(strFunction)
|
||||||
).get_function(strFunction)
|
return _cuda_launch_cache[strKey]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
|
def _pytorch_softsplat(tenIn, tenFlow):
|
||||||
|
"""Pure-PyTorch forward warp via bilinear splatting (scatter_add)."""
|
||||||
|
B, C, H, W = tenIn.shape
|
||||||
|
tenOut = tenIn.new_zeros(B, C, H, W)
|
||||||
|
|
||||||
|
grid_y, grid_x = torch.meshgrid(
|
||||||
|
torch.arange(H, device=tenIn.device, dtype=tenIn.dtype),
|
||||||
|
torch.arange(W, device=tenIn.device, dtype=tenIn.dtype),
|
||||||
|
indexing='ij',
|
||||||
|
)
|
||||||
|
|
||||||
|
flt_x = grid_x.unsqueeze(0) + tenFlow[:, 0, :, :]
|
||||||
|
flt_y = grid_y.unsqueeze(0) + tenFlow[:, 1, :, :]
|
||||||
|
|
||||||
|
valid = torch.isfinite(flt_x) & torch.isfinite(flt_y)
|
||||||
|
flt_x = torch.where(valid, flt_x, torch.zeros_like(flt_x))
|
||||||
|
flt_y = torch.where(valid, flt_y, torch.zeros_like(flt_y))
|
||||||
|
|
||||||
|
nw_x = flt_x.floor().long()
|
||||||
|
nw_y = flt_y.floor().long()
|
||||||
|
frac_x = flt_x - nw_x.to(flt_x.dtype)
|
||||||
|
frac_y = flt_y - nw_y.to(flt_y.dtype)
|
||||||
|
|
||||||
|
w_nw = (1.0 - frac_x) * (1.0 - frac_y) * valid
|
||||||
|
w_ne = frac_x * (1.0 - frac_y) * valid
|
||||||
|
w_sw = (1.0 - frac_x) * frac_y * valid
|
||||||
|
w_se = frac_x * frac_y * valid
|
||||||
|
|
||||||
|
out_flat = tenOut.view(B, C, -1)
|
||||||
|
|
||||||
|
for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]:
|
||||||
|
tx = nw_x + dx
|
||||||
|
ty = nw_y + dy
|
||||||
|
in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H)
|
||||||
|
w_masked = w * in_bounds
|
||||||
|
idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1))
|
||||||
|
idx = idx.unsqueeze(1).expand_as(tenIn)
|
||||||
|
weighted = tenIn * w_masked.unsqueeze(1)
|
||||||
|
out_flat.scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1))
|
||||||
|
|
||||||
|
return tenOut
|
||||||
|
|
||||||
|
|
||||||
@torch.compiler.disable()
|
@torch.compiler.disable()
|
||||||
def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False):
|
def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False):
|
||||||
assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"]
|
assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"]
|
||||||
@@ -366,7 +413,7 @@ class softsplat_func(torch.autograd.Function):
|
|||||||
[tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
|
[tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
|
||||||
)
|
)
|
||||||
|
|
||||||
if tenIn.is_cuda == True:
|
if tenIn.is_cuda and cupy is not None:
|
||||||
cuda_launch(
|
cuda_launch(
|
||||||
cuda_kernel(
|
cuda_kernel(
|
||||||
"softsplat_out",
|
"softsplat_out",
|
||||||
@@ -439,8 +486,8 @@ class softsplat_func(torch.autograd.Function):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif tenIn.is_cuda != True:
|
else:
|
||||||
assert False
|
tenOut = _pytorch_softsplat(tenIn, tenFlow)
|
||||||
|
|
||||||
# end
|
# end
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user