perf: add torch.compile to PyTorch fallback kernels

Wraps _pytorch_softsplat and _pytorch_costvol with torch.compile
for ~6x speedup on ROCm/non-cupy setups. Falls back to eager
execution gracefully if compilation fails.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-11 10:13:30 +02:00
parent 2e75e2d076
commit 83e4b5dd98
3 changed files with 53 additions and 6 deletions
@@ -292,7 +292,7 @@ def cuda_launch(strKey: str):
##########################################################
def _pytorch_softsplat(tenIn, tenFlow):
def _pytorch_softsplat_impl(tenIn, tenFlow):
"""Pure-PyTorch forward warp via bilinear splatting (scatter_add)."""
B, C, H, W = tenIn.shape
tenOut = tenIn.new_zeros(B, C, H, W)
@@ -335,6 +335,22 @@ def _pytorch_softsplat(tenIn, tenFlow):
return tenOut
_softsplat_fn = None
def _pytorch_softsplat(tenIn, tenFlow):
global _softsplat_fn
if _softsplat_fn is None:
try:
_softsplat_fn = torch.compile(_pytorch_softsplat_impl)
except Exception:
_softsplat_fn = _pytorch_softsplat_impl
try:
return _softsplat_fn(tenIn, tenFlow)
except Exception:
_softsplat_fn = _pytorch_softsplat_impl
return _softsplat_fn(tenIn, tenFlow)
@torch.compiler.disable()
def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False):
assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"]