perf: add torch.compile to PyTorch fallback kernels

Wraps _pytorch_softsplat and _pytorch_costvol with torch.compile
for ~6x speedup on ROCm/non-cupy setups. Falls back to eager
execution gracefully if compilation fails.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-11 10:13:30 +02:00
parent 2e75e2d076
commit 83e4b5dd98
3 changed files with 53 additions and 6 deletions
+17 -1
View File
@@ -243,7 +243,7 @@ def cuda_launch(strKey:str):
# end
def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
def _pytorch_costvol_impl(tenOne, tenTwo, intKernelSize):
"""Pure-PyTorch local cost volume via unfold + dot product."""
B, C, H, W = tenOne.shape
pad = (intKernelSize - 1) // 2
@@ -265,6 +265,22 @@ def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
return tenOut
_costvol_fn = None
def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
global _costvol_fn
if _costvol_fn is None:
try:
_costvol_fn = torch.compile(_pytorch_costvol_impl)
except Exception:
_costvol_fn = _pytorch_costvol_impl
try:
return _costvol_fn(tenOne, tenTwo, intKernelSize)
except Exception:
_costvol_fn = _pytorch_costvol_impl
return _costvol_fn(tenOne, tenTwo, intKernelSize)
##########################################################